Unverified commit df949037 authored by Hua Huang, committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for layer norm and RMS norm (#1290)



* Add LayerNormForwardFFI(); add FFI calls in Python
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FFI for RMS norm, all tests passed
Signed-off-by: Hua Huang <huah@nvidia.com>

* Simplify layer & RMS norm FFI calls
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify tensor size calculations
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d7256866
...@@ -13,6 +13,7 @@ from jax import core, dtypes ...@@ -13,6 +13,7 @@ from jax import core, dtypes
from jax.interpreters import mlir from jax.interpreters import mlir
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine.transformer_engine_jax import DType as TEDType
...@@ -25,6 +26,7 @@ from .misc import ( ...@@ -25,6 +26,7 @@ from .misc import (
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype, jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
is_ffi_enabled,
) )
from .quantization import _jax_cast_fp8 from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
...@@ -125,6 +127,19 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -125,6 +127,19 @@ class LayerNormFwdPrimitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
# Output shape is same as the input shape, but the output type is same as the weight type. # Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp # See ln_api.cpp
output_type = g_type.element_type output_type = g_type.element_type
...@@ -142,8 +157,12 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -142,8 +157,12 @@ class LayerNormFwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, output_type), ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma, beta] operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape] operand_shapes = [x_shape, g_shape, b_shape]
...@@ -418,6 +437,21 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -418,6 +437,21 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_backward_ffi"
sm_margin = get_backward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
dz_shape = ir.RankedTensorType(dz.type).shape dz_shape = ir.RankedTensorType(dz.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape rsigma_shape = ir.RankedTensorType(rsigma.type).shape
...@@ -629,6 +663,19 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -629,6 +663,19 @@ class RmsNormFwdPrimitive(BasePrimitive):
""" """
RMSNorm fwd lowering rules RMSNorm fwd lowering rules
""" """
if is_ffi_enabled():
name = "te_rmsnorm_forward_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval = ctx.avals_in x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape x_shape = x_type.shape
...@@ -646,8 +693,12 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -646,8 +693,12 @@ class RmsNormFwdPrimitive(BasePrimitive):
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type), ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type), ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma] operands = [x, gamma]
operand_shapes = [x_shape, g_shape] operand_shapes = [x_shape, g_shape]
...@@ -819,6 +870,21 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -819,6 +870,21 @@ class RmsNormBwdPrimitive(BasePrimitive):
""" """
RMSNorm bwd lowering rules RMSNorm bwd lowering rules
""" """
if is_ffi_enabled():
name = "te_rmsnorm_backward_ffi"
sm_margin = get_backward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
_, x_aval, _, gamma_aval = ctx.avals_in _, x_aval, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape x_shape = x_type.shape
...@@ -835,8 +901,12 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -835,8 +901,12 @@ class RmsNormBwdPrimitive(BasePrimitive):
out_types = [ out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type), ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
ir.RankedTensorType.get( ir.RankedTensorType.get(
dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)
), ),
...@@ -1058,6 +1128,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -1058,6 +1128,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
ctx,
x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_mu_dtype = ir.F32Type.get() ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get() ir_rsigma_dtype = ir.F32Type.get()
...@@ -1079,8 +1165,12 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -1079,8 +1165,12 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma, beta, amax, scale, scale_inv] operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [ operand_shapes = [
...@@ -1345,6 +1435,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1345,6 +1435,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
# Currently only support casting to E4M3 only in C side. # Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn assert out_dtype == jnp.float8_e4m3fn
if is_ffi_enabled():
name = "te_rmsnorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx,
x,
gamma,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -1376,8 +1482,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1376,8 +1482,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma, amax, scale, scale_inv] operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......
...@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd ...@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
...@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler); ...@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);
// Quantization // Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
...@@ -52,15 +52,29 @@ pybind11::dict Registrations() { ...@@ -52,15 +52,29 @@ pybind11::dict Registrations() {
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
// Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
  dict["te_rmsnorm_forward_ffi"] = EncapsulateFFI(RMSNormForwardHandler);
  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFFI(RMSNormForwardFP8Handler);
  dict["te_rmsnorm_backward_ffi"] = EncapsulateFFI(RMSNormBackwardHandler);
// Attention
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
return dict; return dict;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment