Unverified Commit 0e1d9fae authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Bug fix for distributed normalization (#1366)



* fix ctx.aval_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance
---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e4c99b03
...@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { ...@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
class NormalizationPlanRegistry { class NormalizationPlanRegistry {
public: public:
// TODO thread-safe
static NormalizationPlanRegistry& getInstance() { static NormalizationPlanRegistry& getInstance() {
static NormalizationPlanRegistry instance; static thread_local NormalizationPlanRegistry instance;
return instance; return instance;
} }
......
...@@ -147,7 +147,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -147,7 +147,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:] wkspace_aval = ctx.avals_out[-1]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, output_type), ir.RankedTensorType.get(out_shape, output_type),
...@@ -441,7 +441,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -441,7 +441,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
sm_margin = get_backward_sm_margin() sm_margin = get_backward_sm_margin()
wkspace_aval = ctx.avals_out[-4:] wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -650,7 +650,7 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -650,7 +650,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:] wkspace_aval = ctx.avals_out[-1]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type), ir.RankedTensorType.get(out_shape, x_type.element_type),
...@@ -841,7 +841,7 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -841,7 +841,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape) hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-3:] wkspace_aval = ctx.avals_out[-1]
out_types = [ out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(x_shape, x_type.element_type),
...@@ -1088,7 +1088,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -1088,7 +1088,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:] wkspace_aval = ctx.avals_out[-1]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
...@@ -1394,7 +1394,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1394,7 +1394,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1] batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:] wkspace_aval = ctx.avals_out[-1]
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
......
...@@ -83,12 +83,24 @@ pybind11::dict Registrations() { ...@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization // Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); dict["te_layernorm_forward_ffi"] =
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); dict["te_layernorm_forward_fp8_ffi"] =
dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
dict["te_layernorm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
dict["te_rmsnorm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
dict["te_rmsnorm_forward_fp8_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
dict["te_rmsnorm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
// Attention // Attention
pybind11::dict fused_attn_forward_ffi; pybind11::dict fused_attn_forward_ffi;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment