Unverified Commit df949037 authored by Hua Huang's avatar Hua Huang Committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for layer norm and RMS norm (#1290)



* Add LayerNormForwardFFI(); add FFI calls in Python
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FFI for RMS norm, all tests passed
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Simplify layer & RMS norm FFI calls
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify tensor size calculations
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d7256866
...@@ -13,6 +13,7 @@ from jax import core, dtypes ...@@ -13,6 +13,7 @@ from jax import core, dtypes
from jax.interpreters import mlir from jax.interpreters import mlir
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType from transformer_engine.transformer_engine_jax import DType as TEDType
...@@ -25,6 +26,7 @@ from .misc import ( ...@@ -25,6 +26,7 @@ from .misc import (
jax_dtype_to_te_dtype, jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype, jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
is_ffi_enabled,
) )
from .quantization import _jax_cast_fp8 from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
...@@ -125,6 +127,19 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -125,6 +127,19 @@ class LayerNormFwdPrimitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
# Output shape is same as the input shape, but the output type is same as the weight type. # Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp # See ln_api.cpp
output_type = g_type.element_type output_type = g_type.element_type
...@@ -142,8 +157,12 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -142,8 +157,12 @@ class LayerNormFwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, output_type), ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma, beta] operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape] operand_shapes = [x_shape, g_shape, b_shape]
...@@ -418,6 +437,21 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -418,6 +437,21 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_backward_ffi"
sm_margin = get_backward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
dz_shape = ir.RankedTensorType(dz.type).shape dz_shape = ir.RankedTensorType(dz.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape rsigma_shape = ir.RankedTensorType(rsigma.type).shape
...@@ -629,6 +663,19 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -629,6 +663,19 @@ class RmsNormFwdPrimitive(BasePrimitive):
""" """
RMSNorm fwd lowering rules RMSNorm fwd lowering rules
""" """
if is_ffi_enabled():
name = "te_rmsnorm_forward_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval = ctx.avals_in x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape x_shape = x_type.shape
...@@ -646,8 +693,12 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -646,8 +693,12 @@ class RmsNormFwdPrimitive(BasePrimitive):
out_types = [ out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type), ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type), ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma] operands = [x, gamma]
operand_shapes = [x_shape, g_shape] operand_shapes = [x_shape, g_shape]
...@@ -819,6 +870,21 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -819,6 +870,21 @@ class RmsNormBwdPrimitive(BasePrimitive):
""" """
RMSNorm bwd lowering rules RMSNorm bwd lowering rules
""" """
if is_ffi_enabled():
name = "te_rmsnorm_backward_ffi"
sm_margin = get_backward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
_, x_aval, _, gamma_aval = ctx.avals_in _, x_aval, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type) x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape x_shape = x_type.shape
...@@ -835,8 +901,12 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -835,8 +901,12 @@ class RmsNormBwdPrimitive(BasePrimitive):
out_types = [ out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type), ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type), ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
ir.RankedTensorType.get( ir.RankedTensorType.get(
dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)
), ),
...@@ -1058,6 +1128,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -1058,6 +1128,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert g_type == b_type assert g_type == b_type
assert g_shape == b_shape assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
ctx,
x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_mu_dtype = ir.F32Type.get() ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get() ir_rsigma_dtype = ir.F32Type.get()
...@@ -1079,8 +1165,12 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -1079,8 +1165,12 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(batch_shape, ir_mu_dtype), ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma, beta, amax, scale, scale_inv] operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [ operand_shapes = [
...@@ -1345,6 +1435,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1345,6 +1435,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
# Currently only support casting to E4M3 only in C side. # Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn assert out_dtype == jnp.float8_e4m3fn
if is_ffi_enabled():
name = "te_rmsnorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx,
x,
gamma,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
...@@ -1376,8 +1482,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -1376,8 +1482,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, ir_out_dtype), ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype), ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), ir.RankedTensorType.get(
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)), wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
] ]
operands = [x, gamma, amax, scale, scale_inv] operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......
...@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd ...@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len); size_t opaque_len);
...@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler); ...@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);
// Quantization // Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
...@@ -91,6 +91,200 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac ...@@ -91,6 +91,200 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
} }
} }
// Shared implementation behind all four forward FFI handlers
// (LayerNorm / RMSNorm, each with and without FP8 output).
// Pointer arguments a given variant does not use are passed as nullptr by the
// thin wrappers and are only dereferenced behind the matching flag:
//   - beta_buf / mu_buf                       : LayerNorm only (is_layer_norm)
//   - amax_buf / scale_buf / scale_inv_buf /
//     amax_out_buf                            : FP8 output only (is_fp8)
Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buffer_Type *gamma_buf,
                                   Buffer_Type *beta_buf, Buffer_Type *amax_buf,
                                   Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf,
                                   Result_Type *output_buf, Result_Type *mu_buf,
                                   Result_Type *rsigma_buf, Result_Type *amax_out_buf,
                                   Result_Type *wkspace_buf, Result_Type *barrier_buf,
                                   bool zero_centered_gamma, double eps_, int64_t sm_margin_,
                                   bool is_layer_norm, bool is_fp8) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type());
  auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type());
  auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
  auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
  auto *input = x_buf->untyped_data();
  auto *weight = gamma_buf->untyped_data();
  auto *output = (*output_buf)->untyped_data();
  auto *rsigma = (*rsigma_buf)->untyped_data();
  auto *workspace = (*wkspace_buf)->untyped_data();
  auto *barrier = (*barrier_buf)->untyped_data();
  // bias (beta) and mu exist only for LayerNorm; RMSNorm wrappers pass nullptr.
  void *bias = nullptr;
  void *mu = nullptr;
  if (is_layer_norm) {
    bias = beta_buf->untyped_data();
    mu = (*mu_buf)->untyped_data();
  }
  float *amax = nullptr;
  float *scale = nullptr;
  float *scale_inv = nullptr;
  void *amax_out = nullptr;
  auto out_dtype = in_dtype;  // non-FP8 path: output dtype follows the input
  if (is_fp8) {
    amax = reinterpret_cast<float *>(amax_buf->untyped_data());
    scale = reinterpret_cast<float *>(scale_buf->untyped_data());
    scale_inv = reinterpret_cast<float *>(scale_inv_buf->untyped_data());
    amax_out = (*amax_out_buf)->untyped_data();
    // amax is updated in place: the Python lowering declares an
    // operand/output alias, so both names must refer to the same device buffer.
    NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForward primitive");
    out_dtype = DType::kFloat8E4M3;
  }
  // gamma has one element per hidden unit; all leading dims of x are batch.
  auto x_size = product(x_buf->dimensions());
  auto gamma_size = product(gamma_buf->dimensions());
  auto wkspace_size = product((*wkspace_buf)->dimensions());
  auto barrier_size = product((*barrier_buf)->dimensions());
  auto hidden_size = gamma_size;
  auto batch_size = x_size / gamma_size;
  float eps = static_cast<float>(eps_);
  int sm_margin = static_cast<int>(sm_margin_);
  LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                       eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                       wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
                       sm_margin, stream);
  // Converts any pending CUDA error on this stream into an FFI Error_Type.
  return ffi_with_cuda_error_check();
}
// FFI entry point for LayerNorm forward with FP8 output.
// Thin adapter: passes every buffer by address to the shared
// LayerNormForwardImplFFI with is_layer_norm=true and is_fp8=true.
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
                                  Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
                                  Buffer_Type scale_inv_buf, Result_Type output_buf,
                                  Result_Type mu_buf, Result_Type rsigma_buf,
                                  Result_Type amax_out_buf, Result_Type wkspace_buf,
                                  Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
                                  int64_t sm_margin_) {
  return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf,
                                 &scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf,
                                 &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
                                 true,  // is_layer_norm
                                 true   // is_fp8
  );
}

// XLA FFI binding: CUDA stream context, 6 inputs, 6 results, 3 attributes.
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as
// CUDA-graph capturable — confirm against the FFI helper definitions.
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Arg<Buffer_Type>()      // beta
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // mu
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point for plain (non-FP8) LayerNorm forward.
// The FP8-only buffers are absent from this binding, so nullptr is forwarded
// in their place; is_fp8=false guarantees they are never dereferenced.
Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
                               Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf,
                               Result_Type rsigma_buf, Result_Type wkspace_buf,
                               Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
                               int64_t sm_margin_) {
  return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf,
                                 nullptr,  // amax_buf
                                 nullptr,  // scale_buf,
                                 nullptr,  // scale_inv_buf,
                                 &output_buf, &mu_buf, &rsigma_buf,
                                 nullptr,  // amax_out_buf,
                                 &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
                                 true,  // is_layer_norm
                                 false  // is_fp8
  );
}

// XLA FFI binding: CUDA stream context, 3 inputs, 5 results, 3 attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Arg<Buffer_Type>()      // beta
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // mu
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point for RMSNorm forward with FP8 output.
// RMSNorm has no beta/mu, so nullptr is forwarded for those and
// is_layer_norm=false guards every access to them in the shared impl.
Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
                                Buffer_Type amax_buf, Buffer_Type scale_buf,
                                Buffer_Type scale_inv_buf, Result_Type output_buf,
                                Result_Type rsigma_buf, Result_Type amax_out_buf,
                                Result_Type wkspace_buf, Result_Type barrier_buf,
                                bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf,
                                 nullptr,  // beta_buf,
                                 &amax_buf, &scale_buf, &scale_inv_buf, &output_buf,
                                 nullptr,  // mu_buf,
                                 &rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf,
                                 zero_centered_gamma, eps_, sm_margin_,
                                 false,  // is_layer_norm
                                 true    // is_fp8
  );
}

// XLA FFI binding: CUDA stream context, 5 inputs, 5 results, 3 attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point for plain (non-FP8) RMSNorm forward: both the
// LayerNorm-only buffers (beta/mu) and the FP8-only buffers are nullptr.
Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
                             Result_Type output_buf, Result_Type rsigma_buf,
                             Result_Type wkspace_buf, Result_Type barrier_buf,
                             bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf,
                                 nullptr,  // beta_buf,
                                 nullptr,  // amax_buf,
                                 nullptr,  // scale_buf,
                                 nullptr,  // scale_inv_buf,
                                 &output_buf,
                                 nullptr,  // mu_buf,
                                 &rsigma_buf,
                                 nullptr,  // amax_out_buf,
                                 &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_,
                                 false,  // is_layer_norm
                                 false   // is_fp8
  );
}

// XLA FFI binding: CUDA stream context, 2 inputs, 4 results, 3 attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype, DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma, bool is_layer_norm, bool zero_centered_gamma,
...@@ -199,6 +393,140 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace ...@@ -199,6 +393,140 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
} }
} }
// Shared implementation behind both backward FFI handlers
// (LayerNorm and RMSNorm).  mu / dbeta / dbeta_part exist only for LayerNorm;
// the RMSNorm wrapper passes nullptr for them and is_layer_norm=false guards
// every dereference of those pointers.
Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Buffer_Type *x_buf,
                                    Buffer_Type *mu_buf, Buffer_Type *rsigma_buf,
                                    Buffer_Type *gamma_buf, Result_Type *xgrad_buf,
                                    Result_Type *wgrad_buf, Result_Type *dbeta_buf,
                                    Result_Type *wkspace_buf, Result_Type *barrier_buf,
                                    Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf,
                                    bool zero_centered_gamma, double eps_, int64_t sm_margin_,
                                    bool is_layer_norm) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type());
  auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type());
  auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
  auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
  auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type());
  auto *ograd = dz_buf->untyped_data();
  auto *rsigma = rsigma_buf->untyped_data();
  auto *input = x_buf->untyped_data();
  auto *weight = gamma_buf->untyped_data();
  auto *xgrad = (*xgrad_buf)->untyped_data();
  auto *wgrad = (*wgrad_buf)->untyped_data();
  auto *workspace = (*wkspace_buf)->untyped_data();
  auto *barrier = (*barrier_buf)->untyped_data();
  auto *dgamma_part = (*dgamma_part_buf)->untyped_data();
  // LayerNorm-only tensors; dbeta_part_dtype keeps a byte placeholder otherwise.
  void *mu = nullptr;
  void *dbeta = nullptr;
  void *dbeta_part = nullptr;
  auto dbeta_part_dtype = DType::kByte;
  if (is_layer_norm) {
    mu = (*mu_buf).untyped_data();
    dbeta = (*dbeta_buf)->untyped_data();
    dbeta_part = (*dbeta_part_buf)->untyped_data();
    dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type());
  }
  // hidden dim = gamma length; remaining (leading) dims of x form the batch.
  auto x_size = product(x_buf->dimensions());
  auto gamma_size = product(gamma_buf->dimensions());
  auto wkspace_size = product((*wkspace_buf)->dimensions());
  auto barrier_size = product((*barrier_buf)->dimensions());
  auto hidden_size = gamma_size;
  auto batch_size = x_size / gamma_size;
  Shape dgamma_part_shape;
  auto dgamma_part_dims = (*dgamma_part_buf)->dimensions();
  std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
  dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
  Shape dbeta_part_shape;
  if (is_layer_norm) {
    auto dbeta_part_dims = (*dbeta_part_buf)->dimensions();
    std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
    dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
  } else {
    // RMSNorm has no beta: use an empty 2-D placeholder shape.
    dbeta_part_shape.from_vector({0, 0});
  }
  float eps = static_cast<float>(eps_);
  int sm_margin = static_cast<int>(sm_margin_);
  LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
                        dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                        w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                        rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
                        dbeta_part_dtype, sm_margin, stream);
  // Converts any pending CUDA error on this stream into an FFI Error_Type.
  return ffi_with_cuda_error_check();
}
// FFI entry point for LayerNorm backward.
// Thin adapter over the shared LayerNormBackwardImplFFI (is_layer_norm=true).
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
                                Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
                                Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
                                Result_Type wkspace_buf, Result_Type barrier_buf,
                                Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
                                bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf,
                                  &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf,
                                  &dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_,
                                  sm_margin_,
                                  true  // is_layer_norm
  );
}

// XLA FFI binding: CUDA stream context, 5 inputs, 7 results, 3 attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // dz
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // mu
                                  .Arg<Buffer_Type>()      // rsigma
                                  .Arg<Buffer_Type>()      // gamma
                                  .Ret<Buffer_Type>()      // xgrad
                                  .Ret<Buffer_Type>()      // wgrad
                                  .Ret<Buffer_Type>()      // dbeta
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Ret<Buffer_Type>()      // dgamma_part
                                  .Ret<Buffer_Type>()      // dbeta_part
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point for RMSNorm backward.  RMSNorm has no beta, so the
// mu/dbeta/dbeta_part slots are nullptr and is_layer_norm=false guards them.
Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
                              Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf,
                              Result_Type wgrad_buf, Result_Type wkspace_buf,
                              Result_Type barrier_buf, Result_Type dgamma_part_buf,
                              bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf,
                                  nullptr,  // mu_buf
                                  &rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf,
                                  nullptr,  // dbeta_buf,
                                  &wkspace_buf, &barrier_buf, &dgamma_part_buf,
                                  nullptr,  // dbeta_part_buf,
                                  zero_centered_gamma, eps_, sm_margin_,
                                  false  // is_layer_norm
  );
}

// XLA FFI binding: CUDA stream context, 4 inputs, 5 results, 3 attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // dz
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // rsigma
                                  .Arg<Buffer_Type>()      // gamma
                                  .Ret<Buffer_Type>()      // xgrad
                                  .Ret<Buffer_Type>()      // wgrad
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Ret<Buffer_Type>()      // dgamma_part
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) { size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
...@@ -237,72 +565,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -237,72 +565,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
sm_margin, stream); sm_margin, stream);
} }
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *bias = beta_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *amax_out = amax_out_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // gamma
.Arg<Buffer_Type>() // beta
.Arg<Buffer_Type>() // amax
.Arg<Buffer_Type>() // scale
.Arg<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // mu
.Ret<Buffer_Type>() // rsigma
.Ret<Buffer_Type>() // amax_out
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
auto *weight = buffers[1]; auto *weight = buffers[1];
...@@ -376,79 +638,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -376,79 +638,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbeta_part_dtype, sm_margin, stream); dbeta_part_dtype, sm_margin, stream);
} }
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype(dgamma_part_buf->element_type());
auto dbeta_part_dtype = convert_ffi_datatype_to_te_dtype(dbeta_part_buf->element_type());
auto *ograd = dz_buf.untyped_data();
auto *mu = mu_buf.untyped_data();
auto *rsigma = rsigma_buf.untyped_data();
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *xgrad = xgrad_buf->untyped_data();
auto *wgrad = wgrad_buf->untyped_data();
auto *dbeta = dbeta_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data();
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());
auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
Shape dgamma_part_shape, dbeta_part_shape;
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // dz
.Arg<Buffer_Type>() // x
.Arg<Buffer_Type>() // mu
.Arg<Buffer_Type>() // rsigma
.Arg<Buffer_Type>() // gamma
.Ret<Buffer_Type>() // xgrad
.Ret<Buffer_Type>() // wgrad
.Ret<Buffer_Type>() // dbeta
.Ret<Buffer_Type>() // wkspace
.Ret<Buffer_Type>() // barrier
.Ret<Buffer_Type>() // dgamma_part
.Ret<Buffer_Type>() // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0]; auto *input = buffers[0];
auto *weight = buffers[1]; auto *weight = buffers[1];
......
...@@ -52,15 +52,29 @@ pybind11::dict Registrations() { ...@@ -52,15 +52,29 @@ pybind11::dict Registrations() {
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
// Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler); dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
// Attention
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
return dict; return dict;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment