"src/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "721df4f350b1c9463d496568898c3145d4ec55b3"
Unverified commit 0e1d9fae, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Bug fix for distributed normalization (#1366)



* fix ctx.avals_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance
---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent e4c99b03
......@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
class NormalizationPlanRegistry {
public:
// TODO thread-safe
// Returns the per-thread singleton plan registry (Meyers singleton).
// thread_local gives each thread its own registry instance, sidestepping
// data races on the otherwise-unsynchronized plan cache (see TODO above);
// the trade-off is that plans are not shared across threads.
static NormalizationPlanRegistry& getInstance() {
  static thread_local NormalizationPlanRegistry instance;
  return instance;
}
......
......@@ -147,7 +147,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, output_type),
......@@ -441,7 +441,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
sm_margin = get_backward_sm_margin()
wkspace_aval = ctx.avals_out[-4:]
wkspace_aval = ctx.avals_out[-1]
opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size,
hidden_size,
......@@ -650,7 +650,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
......@@ -841,7 +841,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
hidden_size = reduce(operator.mul, g_shape)
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-3:]
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
......@@ -1088,7 +1088,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
......@@ -1394,7 +1394,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
batch_shape = out_shape[:-1]
batch_size = reduce(operator.mul, x_shape) // hidden_size
wkspace_aval = ctx.avals_out[-2:]
wkspace_aval = ctx.avals_out[-1]
out_types = [
ir.RankedTensorType.get(out_shape, ir_out_dtype),
......
......@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
dict["te_layernorm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
dict["te_layernorm_forward_fp8_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
dict["te_layernorm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
dict["te_rmsnorm_forward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
dict["te_rmsnorm_forward_fp8_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
dict["te_rmsnorm_backward_ffi"] =
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
// Attention
pybind11::dict fused_attn_forward_ffi;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.