Unverified commit df949037, authored by Hua Huang and committed by GitHub

[TE/JAX] XLA FFI calls for layer norm and RMS norm (#1290)



* Add LayerNormForwardFFI(); add FFI calls in Python
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FFI for RMS norm, all tests passed
Signed-off-by: Hua Huang <huah@nvidia.com>

* Simplify layer & RMS norm FFI calls
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify tensor size calculations
Signed-off-by: Hua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d7256866
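The pattern this change introduces is the same across all six norm lowering rules: each primitive checks is_ffi_enabled() and, when the FFI path is on, lowers through jax.extend.ffi.ffi_lowering with typed attributes (eps, sm_margin, zero_centered_gamma) instead of packing an opaque descriptor for the legacy custom call. A minimal sketch of that dispatch, assuming the is_ffi_enabled and get_forward_sm_margin helpers shown in the diff; the wrapper function itself is illustrative, not the actual TE lowering:

# Minimal sketch (not the actual TE code) of the FFI-vs-legacy dispatch added to each norm lowering.
from jax.extend import ffi

def layernorm_fwd_lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
    if is_ffi_enabled():  # helper imported from .misc in the diff below
        # Scalars travel as FFI attributes and must match the C++ binding:
        # .Attr<bool>("zero_centered_gamma"), .Attr<double>("eps"), .Attr<int64_t>("sm_margin")
        return ffi.ffi_lowering("te_layernorm_forward_ffi")(
            ctx,
            x,
            gamma,
            beta,
            zero_centered_gamma=zero_centered_gamma,
            eps=epsilon,
            sm_margin=get_forward_sm_margin(),
        )
    # Otherwise fall back to the legacy opaque-descriptor custom call (see the diff below).
    ...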
@@ -13,6 +13,7 @@ from jax import core, dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType

@@ -25,6 +26,7 @@ from .misc import (
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    te_dtype_to_jax_dtype,
    is_ffi_enabled,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
@@ -125,51 +127,68 @@ class LayerNormFwdPrimitive(BasePrimitive):
        assert g_type == b_type
        assert g_shape == b_shape

        if is_ffi_enabled():
            name = "te_layernorm_forward_ffi"
            sm_margin = get_forward_sm_margin()
            out = ffi.ffi_lowering(name)(
                ctx,
                x,
                gamma,
                beta,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            # Output shape is same as the input shape, but the output type is same as the weight type.
            # See ln_api.cpp
            output_type = g_type.element_type
            ir_mu_dtype = ir.F32Type.get()
            ir_rsigma_dtype = ir.F32Type.get()

            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            wkspace_aval, barrier_aval = ctx.avals_out[-2:]

            out_types = [
                ir.RankedTensorType.get(out_shape, output_type),
                ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
                ir.RankedTensorType.get(
                    barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
                ),
            ]
            operands = [x, gamma, beta]
            operand_shapes = [x_shape, g_shape, b_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_forward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                barrier_aval.size,
                (0,),  # no dgamma_part in FWD pass
                (0,),  # no dbeta_part in BWD pass
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                jax_dtype_to_te_dtype(barrier_aval.dtype),
                TEDType.kByte,  # dummy dgamma_part te_dtype
                TEDType.kByte,  # dummy dbeta_part te_dtype
                zero_centered_gamma,
                epsilon,
                sm_margin,
            )

            out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)

        return out

@@ -418,44 +437,59 @@ class LayerNormBwdPrimitive(BasePrimitive):
        assert g_type == b_type
        assert g_shape == b_shape

        if is_ffi_enabled():
            name = "te_layernorm_backward_ffi"
            sm_margin = get_backward_sm_margin()
            out = ffi.ffi_lowering(name)(
                ctx,
                dz,
                x,
                mu,
                rsigma,
                gamma,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            dz_shape = ir.RankedTensorType(dz.type).shape
            mu_shape = ir.RankedTensorType(mu.type).shape
            rsigma_shape = ir.RankedTensorType(rsigma.type).shape

            hidden_size = reduce(operator.mul, g_shape)
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            out_types = [
                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
                for output in ctx.avals_out
            ]

            operands = [dz, mu, rsigma, x, gamma]
            operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_backward_sm_margin()

            wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:]
            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                barrier_aval.size,
                dgamma_part_aval.shape,
                dbeta_part_aval.shape,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                jax_dtype_to_te_dtype(barrier_aval.dtype),
                jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
                jax_dtype_to_te_dtype(dbeta_part_aval.dtype),
                zero_centered_gamma,
                epsilon,
                sm_margin,
            )

            out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)

        return out

@@ -629,51 +663,68 @@ class RmsNormFwdPrimitive(BasePrimitive):
        """
        RMSNorm fwd lowering rules
        """
        if is_ffi_enabled():
            name = "te_rmsnorm_forward_ffi"
            sm_margin = get_forward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name)(
                ctx,
                x,
                gamma,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            x_aval, gamma_aval = ctx.avals_in
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape
            rsigma_element_type = ir.F32Type.get()

            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            wkspace_aval, barrier_aval = ctx.avals_out[-2:]

            out_types = [
                ir.RankedTensorType.get(out_shape, x_type.element_type),
                ir.RankedTensorType.get(batch_shape, rsigma_element_type),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
                ir.RankedTensorType.get(
                    barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
                ),
            ]
            operands = [x, gamma]
            operand_shapes = [x_shape, g_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_forward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                barrier_aval.size,
                (0,),  # no dgamma_part in FWD pass
                (0,),  # no dbeta_part in BWD pass
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                jax_dtype_to_te_dtype(barrier_aval.dtype),
                TEDType.kByte,  # dummy dgamma_part te_dtype
                TEDType.kByte,  # dummy dbeta_part te_dtype
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )

            out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)

        return out

@@ -819,53 +870,72 @@ class RmsNormBwdPrimitive(BasePrimitive):
        """
        RMSNorm bwd lowering rules
        """
        if is_ffi_enabled():
            name = "te_rmsnorm_backward_ffi"
            sm_margin = get_backward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name)(
                ctx,
                dz,
                x,
                rsigma,
                gamma,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            _, x_aval, _, gamma_aval = ctx.avals_in
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape
            dz_shape = ir.RankedTensorType(dz.type).shape
            rsigma_shape = ir.RankedTensorType(rsigma.type).shape

            hidden_size = reduce(operator.mul, g_shape)
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:]

            out_types = [
                ir.RankedTensorType.get(x_shape, x_type.element_type),
                ir.RankedTensorType.get(g_shape, g_type.element_type),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
                ir.RankedTensorType.get(
                    barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
                ),
                ir.RankedTensorType.get(
                    dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)
                ),
            ]
            operands = [dz, rsigma, x, gamma]
            operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_backward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                barrier_aval.size,
                dgamma_part_aval.shape,
                (0,),  # no dbeta_part for RMSnorm
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                jax_dtype_to_te_dtype(barrier_aval.dtype),
                jax_dtype_to_te_dtype(dgamma_part_aval.dtype),
                TEDType.kByte,  # dummy dbeta_part te_dtype
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )

            out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)

        return out

@@ -1058,64 +1128,84 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
        assert g_type == b_type
        assert g_shape == b_shape

        if is_ffi_enabled():
            name = "te_layernorm_forward_fp8_ffi"
            sm_margin = get_forward_sm_margin()
            out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
                ctx,
                x,
                gamma,
                beta,
                amax,
                scale,
                scale_inv,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_mu_dtype = ir.F32Type.get()
            ir_rsigma_dtype = ir.F32Type.get()
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            wkspace_aval, barrier_aval = ctx.avals_out[-2:]

            out_types = [
                ir.RankedTensorType.get(out_shape, ir_out_dtype),
                ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
                ir.RankedTensorType.get(
                    barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
                ),
            ]
            operands = [x, gamma, beta, amax, scale, scale_inv]
            operand_shapes = [
                x_shape,
                g_shape,
                b_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_forward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                barrier_aval.size,
                (0,),  # no dgamma_part in FWD pass
                (0,),  # no dbeta_part in BWD pass
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                jax_dtype_to_te_dtype(barrier_aval.dtype),
                TEDType.kByte,  # dummy dgamma_part te_dtype
                TEDType.kByte,  # dummy dbeta_part te_dtype
                zero_centered_gamma,
                epsilon,
                sm_margin,
            )

            out = custom_caller(
                LayerNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={3: 3}
            )

        return out

@@ -1345,67 +1435,87 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn

        if is_ffi_enabled():
            name = "te_rmsnorm_forward_fp8_ffi"
            sm_margin = get_forward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx,
                x,
                gamma,
                amax,
                scale,
                scale_inv,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in

            assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
            assert amax_aval.dtype == jnp.float32
            assert scale_aval.dtype == jnp.float32
            assert scale_inv_aval.dtype == jnp.float32

            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape

            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_rsigma_dtype = ir.F32Type.get()
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape

            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size

            wkspace_aval, barrier_aval = ctx.avals_out[-2:]

            out_types = [
                ir.RankedTensorType.get(out_shape, ir_out_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
                ir.RankedTensorType.get(
                    barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
                ),
            ]
            operands = [x, gamma, amax, scale, scale_inv]
            operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_forward_sm_margin()

            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                barrier_aval.size,
                (0,),  # no dgamma_part in FWD pass
                (0,),  # no dbeta_part in BWD pass
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                jax_dtype_to_te_dtype(barrier_aval.dtype),
                TEDType.kByte,  # dummy dgamma_part te_dtype
                TEDType.kByte,  # dummy dbeta_part te_dtype
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )

            out = custom_caller(
                RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2}
            )

        return out
...
@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);

// Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...
@@ -91,6 +91,200 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
  }
}

Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buffer_Type *gamma_buf,
Buffer_Type *beta_buf, Buffer_Type *amax_buf,
Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf,
Result_Type *output_buf, Result_Type *mu_buf,
Result_Type *rsigma_buf, Result_Type *amax_out_buf,
Result_Type *wkspace_buf, Result_Type *barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_,
bool is_layer_norm, bool is_fp8) {
auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
auto *input = x_buf->untyped_data();
auto *weight = gamma_buf->untyped_data();
auto *output = (*output_buf)->untyped_data();
auto *rsigma = (*rsigma_buf)->untyped_data();
auto *workspace = (*wkspace_buf)->untyped_data();
auto *barrier = (*barrier_buf)->untyped_data();
void *bias = nullptr;
void *mu = nullptr;
if (is_layer_norm) {
bias = beta_buf->untyped_data();
mu = (*mu_buf)->untyped_data();
}
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
void *amax_out = nullptr;
auto out_dtype = in_dtype;
if (is_fp8) {
amax = reinterpret_cast<float *>(amax_buf->untyped_data());
scale = reinterpret_cast<float *>(scale_buf->untyped_data());
scale_inv = reinterpret_cast<float *>(scale_inv_buf->untyped_data());
amax_out = (*amax_out_buf)->untyped_data();
NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForward primitive");
out_dtype = DType::kFloat8E4M3;
}
auto x_size = product(x_buf->dimensions());
auto gamma_size = product(gamma_buf->dimensions());
auto wkspace_size = product((*wkspace_buf)->dimensions());
auto barrier_size = product((*barrier_buf)->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
return ffi_with_cuda_error_check();
}
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf,
&scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf,
&wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
true, // is_layer_norm
true // is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf,
Result_Type rsigma_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf,
nullptr, // amax_buf
nullptr, // scale_buf,
nullptr, // scale_inv_buf,
&output_buf, &mu_buf, &rsigma_buf,
nullptr, // amax_out_buf,
&wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
true, // is_layer_norm
false // is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type rsigma_buf, Result_Type amax_out_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf,
nullptr, // beta_buf,
&amax_buf, &scale_buf, &scale_inv_buf, &output_buf,
nullptr, // mu_buf,
&rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf,
zero_centered_gamma, eps_, sm_margin_,
false, // is_layer_norm
true // is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Result_Type output_buf, Result_Type rsigma_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf,
nullptr, // beta_buf,
nullptr, // amax_buf,
nullptr, // scale_buf,
nullptr, // scale_inv_buf,
&output_buf,
nullptr, // mu_buf,
&rsigma_buf,
nullptr, // amax_out_buf,
&wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
false, // is_layer_norm
false // is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);

pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,

@@ -199,6 +393,140 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
  }
}

Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Buffer_Type *x_buf,
Buffer_Type *mu_buf, Buffer_Type *rsigma_buf,
Buffer_Type *gamma_buf, Result_Type *xgrad_buf,
Result_Type *wgrad_buf, Result_Type *dbeta_buf,
Result_Type *wkspace_buf, Result_Type *barrier_buf,
Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_,
bool is_layer_norm) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type());
auto *ograd = dz_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *input = x_buf->untyped_data();
auto *weight = gamma_buf->untyped_data();
auto *xgrad = (*xgrad_buf)->untyped_data();
auto *wgrad = (*wgrad_buf)->untyped_data();
auto *workspace = (*wkspace_buf)->untyped_data();
auto *barrier = (*barrier_buf)->untyped_data();
auto *dgamma_part = (*dgamma_part_buf)->untyped_data();
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
auto dbeta_part_dtype = DType::kByte;
if (is_layer_norm) {
mu = (*mu_buf).untyped_data();
dbeta = (*dbeta_buf)->untyped_data();
dbeta_part = (*dbeta_part_buf)->untyped_data();
dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type());
}
auto x_size = product(x_buf->dimensions());
auto gamma_size = product(gamma_buf->dimensions());
auto wkspace_size = product((*wkspace_buf)->dimensions());
auto barrier_size = product((*barrier_buf)->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
Shape dgamma_part_shape;
auto dgamma_part_dims = (*dgamma_part_buf)->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
Shape dbeta_part_shape;
if (is_layer_norm) {
auto dbeta_part_dims = (*dbeta_part_buf)->dimensions();
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
} else {
dbeta_part_shape.from_vector({0, 0});
}
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
return ffi_with_cuda_error_check();
}
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf,
&xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf,
&dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_,
sm_margin_,
true // is_layer_norm
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Ret<Buffer_Type>() // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf,
Result_Type wgrad_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, Result_Type dgamma_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf,
nullptr, // mu_buf
&rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf,
nullptr, // dbeta_buf,
&wkspace_buf, &barrier_buf, &dgamma_part_buf,
nullptr, // dbeta_part_buf,
zero_centered_gamma, eps_, sm_margin_,
false // is_layer_norm
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len) {
  auto *input = buffers[0];

@@ -237,72 +565,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
                       sm_margin, stream);
}

Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *bias = beta_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *amax_out = amax_out_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *weight = buffers[1];

@@ -376,79 +638,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
                        dbeta_part_dtype, sm_margin, stream);
}

Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype(dgamma_part_buf->element_type());
auto dbeta_part_dtype = convert_ffi_datatype_to_te_dtype(dbeta_part_buf->element_type());
auto *ograd = dz_buf.untyped_data();
auto *mu = mu_buf.untyped_data();
auto *rsigma = rsigma_buf.untyped_data();
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *xgrad = xgrad_buf->untyped_data();
auto *wgrad = wgrad_buf->untyped_data();
auto *dbeta = dbeta_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data();
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());
auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
Shape dgamma_part_shape, dbeta_part_shape;
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Ret<Buffer_Type>() // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *weight = buffers[1];
...
@@ -52,15 +52,29 @@ pybind11::dict Registrations() {
  dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
  dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);

  // Transpose
  dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
  dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);

  // Activation
  dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
  dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
  dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);

  // Quantization
  dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);

  // Normalization
  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
  dict["te_rmsnorm_forward_ffi"] = EncapsulateFFI(RMSNormForwardHandler);
  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFFI(RMSNormForwardFP8Handler);
  dict["te_rmsnorm_backward_ffi"] = EncapsulateFFI(RMSNormBackwardHandler);

  // Attention
  dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);

  return dict;
}
...