Unverified commit df949037, authored by Hua Huang and committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for layer norm and RMS norm (#1290)



* Add LayerNormForwardFFI(); add FFI calls in Python
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FFI for RMS norm, all tests passed
Signed-off-by: Hua Huang <huah@nvidia.com>

* Simplify layer & RMS norm FFI calls
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify tensor size calculations
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d7256866
......@@ -13,6 +13,7 @@ from jax import core, dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
......@@ -25,6 +26,7 @@ from .misc import (
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
is_ffi_enabled,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
......@@ -125,6 +127,19 @@ class LayerNormFwdPrimitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
# Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp
output_type = g_type.element_type
......@@ -142,8 +157,12 @@ class LayerNormFwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape]
......@@ -418,6 +437,21 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_backward_ffi"
sm_margin = get_backward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
dz_shape = ir.RankedTensorType(dz.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
......@@ -629,6 +663,19 @@ class RmsNormFwdPrimitive(BasePrimitive):
"""
RMSNorm fwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_forward_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
......@@ -646,8 +693,12 @@ class RmsNormFwdPrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma]
operand_shapes = [x_shape, g_shape]
......@@ -819,6 +870,21 @@ class RmsNormBwdPrimitive(BasePrimitive):
"""
RMSNorm bwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_backward_ffi"
sm_margin = get_backward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
_, x_aval, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
......@@ -835,8 +901,12 @@ class RmsNormBwdPrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
ir.RankedTensorType.get(
dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)
),
......@@ -1058,6 +1128,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
ctx,
x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
......@@ -1079,8 +1165,12 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
......@@ -1345,6 +1435,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
# Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn
if is_ffi_enabled():
name = "te_rmsnorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx,
x,
gamma,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -1376,8 +1482,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......
......@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
......@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);
// Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -52,15 +52,29 @@ pybind11::dict Registrations() {
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
// Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
// Normalization
// These targets are XLA FFI handlers (their symbols are declared with
// XLA_FFI_DECLARE_HANDLER_SYMBOL), so they are registered with EncapsulateFFI, matching every
// other "_ffi" entry in this table. EncapsulateFunction is used in this table only for the
// non-FFI custom-call targets.
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
dict["te_rmsnorm_forward_ffi"] = EncapsulateFFI(RMSNormForwardHandler);
dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFFI(RMSNormForwardFP8Handler);
dict["te_rmsnorm_backward_ffi"] = EncapsulateFFI(RMSNormBackwardHandler);
// Attention
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
return dict;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment