Commit ce18bee7 (unverified), authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Load modules during initialize for Norm and Act primitives (#2219)



Load modules during initialize
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: JAX Toolbox <jax@nvidia.com>
parent 7fa0f554
...@@ -41,16 +41,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D ...@@ -41,16 +41,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
// Activation // Activation
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x); JAXX_Scaling_Mode scaling_mode, bool is_2x);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
......
...@@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
// Initialize-stage entry point for the activation primitive.
//
// Passes ActLuFFI, the stream and the untouched argument list to wrapInStreamCapture and
// returns that helper's Error_Type unchanged. No work is done here beyond that delegation.
// NOTE(review): presumably this exists so module/kernel loading happens in the initialize
// stage rather than on the first execute call — confirm against the primitive's Python side.
//
// The parameter list has to stay type-for-type identical to ActLuFFI's: wrapInStreamCapture
// deduces its argument pack from both the std::function and the forwarded values.
// `is_2x_int` is a bool despite its name.
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                              Result_Type output_buf, Result_Type colwise_output_buf,
                              Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                              Result_Type amax_buf, int64_t act_enum,
                              JAXX_Scaling_Mode scaling_mode, bool is_2x_int) {
  // Type-erase the execute-stage implementation so the helper can deduce its signature.
  const auto execute_impl = std::function(ActLuFFI);

  Error_Type status = wrapInStreamCapture(execute_impl, stream,
                                          // inputs
                                          input_buf, scale_buf,
                                          // outputs
                                          output_buf, colwise_output_buf, scale_inv_buf,
                                          colwise_scale_inv_buf, amax_buf,
                                          // attributes
                                          act_enum, scaling_mode, is_2x_int);
  return status;
}
// Defines the ActLuInitializeHandler symbol: ActLuInitializeFFI bound to the FFI_Initialize
// execution stage. The Ctx/Arg/Ret/Attr chain is positional and mirrors the parameter order of
// ActLuInitializeFFI (the "is_2x" attribute feeds the `is_2x_int` parameter).
// Unlike the ActLuHandler definition, no traits argument follows the binding.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode, bool is_2x) {
...@@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
// Initialize-stage entry point for the fused dact / dbias / quantize primitive.
//
// Delegates to wrapInStreamCapture with DActLuDBiasQuantizeFFI as the callable and every
// argument passed through in its original order; the helper's Error_Type is returned as-is.
// NOTE(review): presumably present so module/kernel loading happens during the initialize
// stage — confirm.
//
// The parameter types must match DActLuDBiasQuantizeFFI exactly, because wrapInStreamCapture
// deduces one argument pack from both the std::function and the forwarded values.
Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                            Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                            Result_Type output_buf, Result_Type colwise_output_buf,
                                            Result_Type scale_inv_buf,
                                            Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
                                            Result_Type dbias_buf, Result_Type workspace_buf,
                                            JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
                                            bool is_2x, bool is_dbias) {
  // Type-erase the execute-stage implementation so the helper can deduce its signature.
  const auto execute_impl = std::function(DActLuDBiasQuantizeFFI);

  Error_Type status = wrapInStreamCapture(execute_impl, stream,
                                          // inputs
                                          input_buf, act_input_buf, scale_buf,
                                          // outputs
                                          output_buf, colwise_output_buf, scale_inv_buf,
                                          colwise_scale_inv_buf, amax_buf, dbias_buf,
                                          workspace_buf,
                                          // attributes
                                          scaling_mode, act_enum, is_2x, is_dbias);
  return status;
}
// Defines the DActLuDBiasQuantizeInitializeHandler symbol: DActLuDBiasQuantizeInitializeFFI
// bound to the FFI_Initialize execution stage. The Ctx/Arg/Ret/Attr chain is positional and
// mirrors the parameter order of DActLuDBiasQuantizeInitializeFFI (three inputs, seven results,
// then scaling_mode / act_enum / is_2x / is_dbias).
// Unlike the DActLuDBiasQuantizeHandler definition, no traits argument follows the binding.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
DActLuDBiasQuantizeInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>; ...@@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
using Dictionary = xla::ffi::Dictionary; using Dictionary = xla::ffi::Dictionary;
constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare; constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare;
// Alias for XLA FFI's kInitialize execution stage; used as the template argument of
// FFI::Bind<...>() when defining the *InitializeHandler symbols.
constexpr auto FFI_Initialize = xla::ffi::ExecutionStage::kInitialize;
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type);
...@@ -106,5 +107,19 @@ inline static size_t te_dtype_bytes(const DType& type) { ...@@ -106,5 +107,19 @@ inline static size_t te_dtype_bytes(const DType& type) {
} }
} }
// Invokes `func(stream, args...)` while `stream` is in relaxed-mode CUDA stream capture, then
// ends the capture and immediately destroys the captured graph. The graph is never launched
// here; only the Error_Type produced by `func` is handed back to the caller.
//
// Template deduction note: `Args` is deduced from BOTH the std::function signature and the
// trailing values, so each forwarded value must have exactly the parameter type that `func`
// declares (e.g. an `int` parameter cannot be fed from an `int64_t` variable).
//
// \param func    Callable to run under capture; receives `stream` followed by `args`.
// \param stream  CUDA stream that is put into capture mode for the duration of the call.
// \param args    Arguments handed on to `func`. They are taken by value, so the std::forward
//                below moves them into the call.
// \return        The Error_Type returned by `func`.
//
// Error handling: failures of the CUDA calls are reported through NVTE_CHECK_CUDA. If `func`
// itself throws, the capture is ended (best effort) and any partial graph is destroyed before
// the exception is rethrown, so the stream is not left stuck in capture mode.
template <typename... Args>
Error_Type wrapInStreamCapture(std::function<Error_Type(cudaStream_t, Args...)> func,
                               cudaStream_t stream, Args... args) {
  cudaGraph_t graph{};
  NVTE_CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));

  // True from a successful BeginCapture until EndCapture has been attempted once.
  bool capture_open = true;
  try {
    Error_Type error = func(stream, std::forward<Args>(args)...);
    capture_open = false;  // EndCapture terminates the capture even when it reports an error.
    NVTE_CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
    NVTE_CHECK_CUDA(cudaGraphDestroy(graph));
    return error;
  } catch (...) {
    if (capture_open) {
      // `func` threw mid-capture. Close the capture without NVTE_CHECK_CUDA so that the original
      // exception is the one that propagates; return codes are deliberately ignored.
      (void)cudaStreamEndCapture(stream, &graph);
      if (graph != nullptr) {
        (void)cudaGraphDestroy(graph);
      }
    }
    throw;
  }
}
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, ...@@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
// Initialize-stage entry point for the normalization forward primitive.
//
// Delegates to wrapInStreamCapture with NormForwardFFI as the callable and every argument
// passed through in its original order; the helper's Error_Type is returned unchanged.
// NOTE(review): presumably present so module/kernel loading happens during the initialize
// stage — confirm.
//
// `norm_type` is declared `int` here even though the handler binding decodes the attribute as
// int64_t. Do not "fix" this in isolation: wrapInStreamCapture deduces a single argument pack
// from both the std::function and the forwarded values, so the type has to agree with
// NormForwardFFI's own declaration (assumed to be `int` as well — verify).
Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf,
                                    Buffer_Type gamma_buf, Buffer_Type beta_buf,
                                    Result_Type output_buf, Result_Type colwise_output_buf,
                                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                                    Result_Type amax_buf, Result_Type mu_buf,
                                    Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type,
                                    bool zero_centered_gamma, double epsilon, int64_t sm_margin,
                                    JAXX_Scaling_Mode scaling_mode, bool is_2x) {
  // Type-erase the execute-stage implementation so the helper can deduce its signature.
  const auto execute_impl = std::function(NormForwardFFI);

  Error_Type status = wrapInStreamCapture(execute_impl, stream,
                                          // inputs
                                          x_buf, scale_buf, gamma_buf, beta_buf,
                                          // outputs
                                          output_buf, colwise_output_buf, scale_inv_buf,
                                          colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf,
                                          wkspace_buf,
                                          // attributes
                                          norm_type, zero_centered_gamma, epsilon, sm_margin,
                                          scaling_mode, is_2x);
  return status;
}
// Defines the NormForwardInitializeHandler symbol: NormForwardInitializeFFI bound to the
// FFI_Initialize execution stage. The Ctx/Arg/Ret/Attr chain is positional and mirrors the
// parameter order of NormForwardInitializeFFI (four inputs, eight results, six attributes).
// "norm_type" is decoded as int64_t here while the bound function declares the parameter as
// `int`, so the value is converted when the function is called.
// Unlike the NormForwardHandler definition, no traits argument follows the binding.
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise_output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // colwise_scale_inv
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<double>("epsilon")
.Attr<int64_t>("sm_margin")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
DType w_dtype, NVTE_Norm_Type norm_type, DType w_dtype, NVTE_Norm_Type norm_type,
bool zero_centered_gamma, int sm_margin) { bool zero_centered_gamma, int sm_margin) {
...@@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI, ...@@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI,
.Attr<int64_t>("sm_margin"), .Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
// Initialize-stage entry point for the normalization backward primitive.
//
// Delegates to wrapInStreamCapture with NormBackwardFFI as the callable and every argument
// passed through in its original order; the helper's Error_Type is returned unchanged.
// NOTE(review): presumably present so module/kernel loading happens during the initialize
// stage — confirm.
//
// The parameter types must match NormBackwardFFI exactly, because wrapInStreamCapture deduces
// one argument pack from both the std::function and the forwarded values.
Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
                                     Buffer_Type mu_buf, Buffer_Type rsigma_buf,
                                     Buffer_Type gamma_buf, Result_Type xgrad_buf,
                                     Result_Type wgrad_buf, Result_Type dbeta_buf,
                                     Result_Type wkspace_buf, int64_t norm_type,
                                     bool zero_centered_gamma, int64_t sm_margin) {
  // Type-erase the execute-stage implementation so the helper can deduce its signature.
  const auto execute_impl = std::function(NormBackwardFFI);

  Error_Type status = wrapInStreamCapture(execute_impl, stream,
                                          // inputs
                                          dz_buf, x_buf, mu_buf, rsigma_buf, gamma_buf,
                                          // outputs
                                          xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf,
                                          // attributes
                                          norm_type, zero_centered_gamma, sm_margin);
  return status;
}
// Defines the NormBackwardInitializeHandler symbol: NormBackwardInitializeFFI bound to the
// FFI_Initialize execution stage. The Ctx/Arg/Ret/Attr chain is positional and mirrors the
// parameter order of NormBackwardInitializeFFI (five inputs, four results, three attributes).
// Unlike the NormBackwardHandler definition, no traits argument follows the binding.
XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Attr<int64_t>("norm_type")
.Attr<bool>("zero_centered_gamma")
.Attr<int64_t>("sm_margin"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -22,8 +22,12 @@ pybind11::dict Registrations() { ...@@ -22,8 +22,12 @@ pybind11::dict Registrations() {
pybind11::dict dict; pybind11::dict dict;
// Activation // Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_ffi"] =
dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler); pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(ActLuHandler));
dict["te_dact_dbias_quantize_ffi"] = pybind11::dict(
pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler));
// Quantization // Quantization
dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
...@@ -44,9 +48,11 @@ pybind11::dict Registrations() { ...@@ -44,9 +48,11 @@ pybind11::dict Registrations() {
// Normalization // Normalization
dict["te_norm_forward_ffi"] = dict["te_norm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler)); pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler));
dict["te_norm_backward_ffi"] = dict["te_norm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler),
pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler)); pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler));
// Attention // Attention
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment