Unverified Commit e5ffaa76 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Added prepare phase for the FusedAttnForwardFFI (#1313)



* added prepare phase for the FusedAttnForwardFFI

* enabled FusedAttnForwardFFI by default

* moved prepare phase into pybind

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 4d65073f
......@@ -47,6 +47,7 @@ include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
pycudnn.cpp
cudnn_utils.cpp
transformer_engine.cpp
common.cu
transpose/cast_transpose.cu
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../fused_attn/utils.h"
#include "transformer_engine/cudnn.h"
namespace transformer_engine {

/*! \brief Force cuDNN handle creation for the calling thread.
 *
 * Called so that the cudnnCreate() allocation happens outside of any CUDA
 * graph capture region (e.g. from the XLA custom-call prepare/initialize
 * phase). Only the creation side effect is needed; the handle itself is
 * cached by cudnnExecutionPlanManager and the returned value is unused.
 * [[maybe_unused]] silences the otherwise-guaranteed unused-variable warning.
 */
void nvte_cudnn_handle_init() {
  [[maybe_unused]] auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle();
}

} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file cudnn.h
* \brief Helper for cuDNN initialization
*/
#ifndef TRANSFORMER_ENGINE_CUDNN_H_
#define TRANSFORMER_ENGINE_CUDNN_H_
#include "transformer_engine.h"
/*! \namespace transformer_engine
*/
namespace transformer_engine {
/*! \brief Eagerly creates the cuDNN handle for the current thread.
 *
 * TE/JAX cudaGraph requires the cuDNN initialization to happen outside of the capturing
 * region. This function is a helper to call cudnnCreate(), which allocates memory for the
 * handle. The function will be called in the initialize() (prepare) phase of the related
 * XLA custom calls.
 */
void nvte_cudnn_handle_init();
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_CUDNN_H_
......@@ -379,7 +379,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
wkspace_aval = ctx.avals_out[-1]
if is_ffi_enabled() and bool(os.getenv("NVTE_JAX_FUSED_ATTN_WITH_FFI")):
if is_ffi_enabled():
name = "te_fused_attn_forward_ffi"
out = ffi.ffi_lowering(name)(
ctx,
......
......@@ -289,6 +289,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// Legacy XLA custom-call entry point for the fused-attention backward pass.
// `opaque` carries the serialized descriptor; `buffers` are the device operands/results.
void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Cudnn helpers
// FFI handler symbol that pre-creates the cuDNN handle; bound to the XLA
// prepare stage (defined in the cuDNN extension .cpp, declared here for registration).
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

} // namespace jax
} // namespace transformer_engine
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/cudnn.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {

// XLA FFI handler executed during the Prepare stage (i.e. before any CUDA
// graph capture) solely to force cuDNN handle creation. All buffers, results
// and attributes are accepted variadically and ignored — the handler has no
// data inputs or outputs of its own.
Error_Type CudnnHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                              Dictionary attrs) {
  nvte_cudnn_handle_init();
  // Convert any pending CUDA error into an FFI error (see ffi_with_cuda_error_check).
  return ffi_with_cuda_error_check();
}

// Bind with FFI_Prepare so XLA invokes this handler in the prepare phase;
// RemainingArgs/RemainingRets/Attrs make the binding signature-agnostic.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CudnnHandleInitHandler, CudnnHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());

} // namespace jax
} // namespace transformer_engine
......@@ -16,10 +16,14 @@ namespace jax {
// Convenience aliases for the XLA FFI types used by TE/JAX custom-call handlers.
using Buffer_Type = xla::ffi::AnyBuffer;
using Result_Type = xla::ffi::Result<xla::ffi::AnyBuffer>;
// Variadic argument/result bindings: let a handler accept any number of
// buffers (used by signature-agnostic handlers such as the cuDNN init one).
using Variadic_Buffer_Type = xla::ffi::RemainingArgs;
using Variadic_Result_Type = xla::ffi::RemainingRets;
using Error_Type = xla::ffi::Error;
using FFI = xla::ffi::Ffi;
using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
using Dictionary = xla::ffi::Dictionary;
// Execution stage tag for binding handlers that run in XLA's Prepare phase.
constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare;
// Traits marking a handler as safe to record into a CUDA graph (command buffer).
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type);
......
......@@ -78,7 +78,10 @@ pybind11::dict Registrations() {
dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
// Attention
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
pybind11::dict fused_attn_forward_ffi;
fused_attn_forward_ffi["prepare"] = EncapsulateFFI(CudnnHandleInitHandler);
fused_attn_forward_ffi["execute"] = EncapsulateFFI(FusedAttnForwardHandler);
dict["te_fused_attn_forward_ffi"] = fused_attn_forward_ffi;
return dict;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment