Unverified commit df949037, authored by Hua Huang, committed by GitHub
Browse files

[TE/JAX] XLA FFI calls for layer norm and RMS norm (#1290)



* Add LayerNormForwardFFI(); add FFI calls in Python
Signed-off-by: Hua Huang <huah@nvidia.com>

* Add FFI for RMS norm, all tests passed
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* Simplify layer & RMS norm FFI calls
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Simplify tensor size calculations
Signed-off-by: default avatarHua Huang <huah@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Hua Huang <huah@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent d7256866
......@@ -13,6 +13,7 @@ from jax import core, dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
from transformer_engine import transformer_engine_jax
from transformer_engine.transformer_engine_jax import DType as TEDType
......@@ -25,6 +26,7 @@ from .misc import (
jax_dtype_to_te_dtype,
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
is_ffi_enabled,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
......@@ -125,6 +127,19 @@ class LayerNormFwdPrimitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
beta,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
# Output shape is same as the input shape, but the output type is same as the weight type.
# See ln_api.cpp
output_type = g_type.element_type
......@@ -142,8 +157,12 @@ class LayerNormFwdPrimitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, beta]
operand_shapes = [x_shape, g_shape, b_shape]
......@@ -418,6 +437,21 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_backward_ffi"
sm_margin = get_backward_sm_margin()
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
mu,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
dz_shape = ir.RankedTensorType(dz.type).shape
mu_shape = ir.RankedTensorType(mu.type).shape
rsigma_shape = ir.RankedTensorType(rsigma.type).shape
......@@ -629,6 +663,19 @@ class RmsNormFwdPrimitive(BasePrimitive):
"""
RMSNorm fwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_forward_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
x,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
......@@ -646,8 +693,12 @@ class RmsNormFwdPrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(out_shape, x_type.element_type),
ir.RankedTensorType.get(batch_shape, rsigma_element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma]
operand_shapes = [x_shape, g_shape]
......@@ -819,6 +870,21 @@ class RmsNormBwdPrimitive(BasePrimitive):
"""
RMSNorm bwd lowering rules
"""
if is_ffi_enabled():
name = "te_rmsnorm_backward_ffi"
sm_margin = get_backward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name)(
ctx,
dz,
x,
rsigma,
gamma,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
_, x_aval, _, gamma_aval = ctx.avals_in
x_type = ir.RankedTensorType(x.type)
x_shape = x_type.shape
......@@ -835,8 +901,12 @@ class RmsNormBwdPrimitive(BasePrimitive):
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
ir.RankedTensorType.get(
dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)
),
......@@ -1058,6 +1128,22 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
assert g_type == b_type
assert g_shape == b_shape
if is_ffi_enabled():
name = "te_layernorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
ctx,
x,
gamma,
beta,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
ir_mu_dtype = ir.F32Type.get()
ir_rsigma_dtype = ir.F32Type.get()
......@@ -1079,8 +1165,12 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, beta, amax, scale, scale_inv]
operand_shapes = [
......@@ -1345,6 +1435,22 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
# Currently only support casting to E4M3 only in C side.
assert out_dtype == jnp.float8_e4m3fn
if is_ffi_enabled():
name = "te_rmsnorm_forward_fp8_ffi"
sm_margin = get_forward_sm_margin()
zero_centered_gamma = False # RMSNorm doesn't support zero_centered_gamma
out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
ctx,
x,
gamma,
amax,
scale,
scale_inv,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
sm_margin=sm_margin,
)
else:
x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -1376,8 +1482,12 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
ir.RankedTensorType.get(out_shape, ir_out_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(
wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
),
ir.RankedTensorType.get(
barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)
),
]
operands = [x, gamma, amax, scale, scale_inv]
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
......
......@@ -196,6 +196,8 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len);
......@@ -212,10 +214,16 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);
// Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
......
......@@ -91,6 +91,200 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
}
}
// Shared implementation behind the four forward-norm FFI handlers
// (LayerNorm / RMSNorm x plain / FP8).  Buffers that a given variant does
// not use are passed in as nullptr by the thin wrapper entry points below
// and are only dereferenced behind the matching is_layer_norm / is_fp8 flag.
Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buffer_Type *gamma_buf,
Buffer_Type *beta_buf, Buffer_Type *amax_buf,
Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf,
Result_Type *output_buf, Result_Type *mu_buf,
Result_Type *rsigma_buf, Result_Type *amax_out_buf,
Result_Type *wkspace_buf, Result_Type *barrier_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_,
bool is_layer_norm, bool is_fp8) {
// Dtypes and raw pointers for the buffers common to every variant.
auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
auto *input = x_buf->untyped_data();
auto *weight = gamma_buf->untyped_data();
auto *output = (*output_buf)->untyped_data();
auto *rsigma = (*rsigma_buf)->untyped_data();
auto *workspace = (*wkspace_buf)->untyped_data();
auto *barrier = (*barrier_buf)->untyped_data();
// LayerNorm-only buffers: beta (bias) and the saved mean.  They stay null
// for RMSNorm, which has neither.
void *bias = nullptr;
void *mu = nullptr;
if (is_layer_norm) {
bias = beta_buf->untyped_data();
mu = (*mu_buf)->untyped_data();
}
// FP8-only buffers.  When quantizing, the output dtype switches to E4M3 and
// the amax input must alias the amax output (checked below); otherwise
// these stay null and the output keeps the input dtype.
float *amax = nullptr;
float *scale = nullptr;
float *scale_inv = nullptr;
void *amax_out = nullptr;
auto out_dtype = in_dtype;
if (is_fp8) {
amax = reinterpret_cast<float *>(amax_buf->untyped_data());
scale = reinterpret_cast<float *>(scale_buf->untyped_data());
scale_inv = reinterpret_cast<float *>(scale_inv_buf->untyped_data());
amax_out = (*amax_out_buf)->untyped_data();
NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForward primitive");
out_dtype = DType::kFloat8E4M3;
}
// hidden_size is the gamma length; all remaining leading dims of x are
// flattened into batch_size.
auto x_size = product(x_buf->dimensions());
auto gamma_size = product(gamma_buf->dimensions());
auto wkspace_size = product((*wkspace_buf)->dimensions());
auto barrier_size = product((*barrier_buf)->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
// Surface any pending CUDA error through the FFI error channel.
return ffi_with_cuda_error_check();
}
// FFI entry point: LayerNorm forward with FP8-quantized output.  Thin
// adapter that receives buffers by value from the XLA binding and hands
// their addresses to the shared LayerNormForwardImplFFI.
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x, Buffer_Type gamma,
                                  Buffer_Type beta, Buffer_Type amax, Buffer_Type scale,
                                  Buffer_Type scale_inv, Result_Type output, Result_Type mu,
                                  Result_Type rsigma, Result_Type amax_out, Result_Type wkspace,
                                  Result_Type barrier, bool zero_centered_gamma, double eps_,
                                  int64_t sm_margin_) {
  constexpr bool is_layer_norm = true;
  constexpr bool is_fp8 = true;
  return LayerNormForwardImplFFI(stream, &x, &gamma, &beta, &amax, &scale, &scale_inv, &output,
                                 &mu, &rsigma, &amax_out, &wkspace, &barrier, zero_centered_gamma,
                                 eps_, sm_margin_, is_layer_norm, is_fp8);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Arg<Buffer_Type>()      // beta
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // mu
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point: LayerNorm forward without FP8 quantization.  The FP8
// scaling buffers do not exist in this variant, so nullptr placeholders are
// forwarded in their positions.
Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x, Buffer_Type gamma,
                               Buffer_Type beta, Result_Type output, Result_Type mu,
                               Result_Type rsigma, Result_Type wkspace, Result_Type barrier,
                               bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  constexpr bool is_layer_norm = true;
  constexpr bool is_fp8 = false;
  return LayerNormForwardImplFFI(stream, &x, &gamma, &beta,
                                 /*amax_buf=*/nullptr, /*scale_buf=*/nullptr,
                                 /*scale_inv_buf=*/nullptr, &output, &mu, &rsigma,
                                 /*amax_out_buf=*/nullptr, &wkspace, &barrier,
                                 zero_centered_gamma, eps_, sm_margin_, is_layer_norm, is_fp8);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Arg<Buffer_Type>()      // beta
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // mu
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point: RMSNorm forward with FP8-quantized output.  RMSNorm has
// no beta or mu, so nullptr placeholders are forwarded in their positions.
Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x, Buffer_Type gamma,
                                Buffer_Type amax, Buffer_Type scale, Buffer_Type scale_inv,
                                Result_Type output, Result_Type rsigma, Result_Type amax_out,
                                Result_Type wkspace, Result_Type barrier,
                                bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  constexpr bool is_layer_norm = false;
  constexpr bool is_fp8 = true;
  return LayerNormForwardImplFFI(stream, &x, &gamma, /*beta_buf=*/nullptr, &amax, &scale,
                                 &scale_inv, &output, /*mu_buf=*/nullptr, &rsigma, &amax_out,
                                 &wkspace, &barrier, zero_centered_gamma, eps_, sm_margin_,
                                 is_layer_norm, is_fp8);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point: RMSNorm forward without FP8 quantization.  Neither the
// LayerNorm-only buffers (beta, mu) nor the FP8 scaling buffers exist here,
// so nullptr placeholders fill their positions.
Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x, Buffer_Type gamma,
                             Result_Type output, Result_Type rsigma, Result_Type wkspace,
                             Result_Type barrier, bool zero_centered_gamma, double eps_,
                             int64_t sm_margin_) {
  constexpr bool is_layer_norm = false;
  constexpr bool is_fp8 = false;
  return LayerNormForwardImplFFI(stream, &x, &gamma, /*beta_buf=*/nullptr,
                                 /*amax_buf=*/nullptr, /*scale_buf=*/nullptr,
                                 /*scale_inv_buf=*/nullptr, &output, /*mu_buf=*/nullptr, &rsigma,
                                 /*amax_out_buf=*/nullptr, &wkspace, &barrier,
                                 zero_centered_gamma, eps_, sm_margin_, is_layer_norm, is_fp8);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // gamma
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // rsigma
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType w_dtype,
bool is_layer_norm, bool zero_centered_gamma,
......@@ -199,6 +393,140 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
}
}
// Shared implementation behind the LayerNorm / RMSNorm backward FFI
// handlers.  mu, dbeta and dbeta_part are LayerNorm-only: the RMSNorm
// wrapper passes them as nullptr and they are only dereferenced when
// is_layer_norm is set.
Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Buffer_Type *x_buf,
Buffer_Type *mu_buf, Buffer_Type *rsigma_buf,
Buffer_Type *gamma_buf, Result_Type *xgrad_buf,
Result_Type *wgrad_buf, Result_Type *dbeta_buf,
Result_Type *wkspace_buf, Result_Type *barrier_buf,
Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_,
bool is_layer_norm) {
// Dtypes and raw pointers for the buffers common to both variants.
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type());
auto *ograd = dz_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *input = x_buf->untyped_data();
auto *weight = gamma_buf->untyped_data();
auto *xgrad = (*xgrad_buf)->untyped_data();
auto *wgrad = (*wgrad_buf)->untyped_data();
auto *workspace = (*wkspace_buf)->untyped_data();
auto *barrier = (*barrier_buf)->untyped_data();
auto *dgamma_part = (*dgamma_part_buf)->untyped_data();
// LayerNorm-only outputs; kByte is a placeholder dtype for the case where
// no dbeta_part buffer exists (RMSNorm).
void *mu = nullptr;
void *dbeta = nullptr;
void *dbeta_part = nullptr;
auto dbeta_part_dtype = DType::kByte;
if (is_layer_norm) {
mu = (*mu_buf).untyped_data();
dbeta = (*dbeta_buf)->untyped_data();
dbeta_part = (*dbeta_part_buf)->untyped_data();
dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type());
}
// hidden_size is the gamma length; all remaining leading dims of x are
// flattened into batch_size.
auto x_size = product(x_buf->dimensions());
auto gamma_size = product(gamma_buf->dimensions());
auto wkspace_size = product((*wkspace_buf)->dimensions());
auto barrier_size = product((*barrier_buf)->dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
Shape dgamma_part_shape;
auto dgamma_part_dims = (*dgamma_part_buf)->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
Shape dbeta_part_shape;
if (is_layer_norm) {
auto dbeta_part_dims = (*dbeta_part_buf)->dimensions();
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
} else {
// RMSNorm has no dbeta_part; pass a degenerate (0, 0) shape downstream.
dbeta_part_shape.from_vector({0, 0});
}
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
// Surface any pending CUDA error through the FFI error channel.
return ffi_with_cuda_error_check();
}
// FFI entry point: LayerNorm backward.  Every buffer exists in this
// variant; the wrapper simply forwards addresses to the shared
// LayerNormBackwardImplFFI.
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz, Buffer_Type x,
                                Buffer_Type mu, Buffer_Type rsigma, Buffer_Type gamma,
                                Result_Type xgrad, Result_Type wgrad, Result_Type dbeta,
                                Result_Type wkspace, Result_Type barrier,
                                Result_Type dgamma_part, Result_Type dbeta_part,
                                bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
  constexpr bool is_layer_norm = true;
  return LayerNormBackwardImplFFI(stream, &dz, &x, &mu, &rsigma, &gamma, &xgrad, &wgrad, &dbeta,
                                  &wkspace, &barrier, &dgamma_part, &dbeta_part,
                                  zero_centered_gamma, eps_, sm_margin_, is_layer_norm);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // dz
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // mu
                                  .Arg<Buffer_Type>()      // rsigma
                                  .Arg<Buffer_Type>()      // gamma
                                  .Ret<Buffer_Type>()      // xgrad
                                  .Ret<Buffer_Type>()      // wgrad
                                  .Ret<Buffer_Type>()      // dbeta
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Ret<Buffer_Type>()      // dgamma_part
                                  .Ret<Buffer_Type>()      // dbeta_part
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
// FFI entry point: RMSNorm backward.  RMSNorm has no mu, dbeta, or
// dbeta_part, so nullptr placeholders fill their positions.
Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz, Buffer_Type x,
                              Buffer_Type rsigma, Buffer_Type gamma, Result_Type xgrad,
                              Result_Type wgrad, Result_Type wkspace, Result_Type barrier,
                              Result_Type dgamma_part, bool zero_centered_gamma, double eps_,
                              int64_t sm_margin_) {
  constexpr bool is_layer_norm = false;
  return LayerNormBackwardImplFFI(stream, &dz, &x, /*mu_buf=*/nullptr, &rsigma, &gamma, &xgrad,
                                  &wgrad, /*dbeta_buf=*/nullptr, &wkspace, &barrier, &dgamma_part,
                                  /*dbeta_part_buf=*/nullptr, zero_centered_gamma, eps_,
                                  sm_margin_, is_layer_norm);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // dz
                                  .Arg<Buffer_Type>()      // x
                                  .Arg<Buffer_Type>()      // rsigma
                                  .Arg<Buffer_Type>()      // gamma
                                  .Ret<Buffer_Type>()      // xgrad
                                  .Ret<Buffer_Type>()      // wgrad
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Ret<Buffer_Type>()      // barrier
                                  .Ret<Buffer_Type>()      // dgamma_part
                                  .Attr<bool>("zero_centered_gamma")
                                  .Attr<double>("eps")
                                  .Attr<int64_t>("sm_margin"),
                              FFI_CudaGraph_Traits);
void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
size_t opaque_len) {
auto *input = buffers[0];
......@@ -237,72 +565,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
sm_margin, stream);
}
// NOTE(review): this is the pre-refactor FP8 layer-norm forward FFI handler
// that inlines all buffer unpacking; the diff removes it in favor of the
// shared LayerNormForwardImplFFI marshaling earlier in this file.
Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf,
Buffer_Type beta_buf, Buffer_Type amax_buf, Buffer_Type scale_buf,
Buffer_Type scale_inv_buf, Result_Type output_buf,
Result_Type mu_buf, Result_Type rsigma_buf,
Result_Type amax_out_buf, Result_Type wkspace_buf,
Result_Type barrier_buf, bool zero_centered_gamma, double eps_,
int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *bias = beta_buf.untyped_data();
auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
auto *output = output_buf->untyped_data();
auto *mu = mu_buf->untyped_data();
auto *rsigma = rsigma_buf->untyped_data();
auto *amax_out = amax_out_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
// amax is donated to amax_out via operand aliasing on the Python side.
NVTE_CHECK(amax_out == amax,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive");
// hidden_size is the gamma length; remaining leading dims of x are batch.
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
// FP8 path always casts the normalized output to E4M3.
auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>()  // stream
.Arg<Buffer_Type>()  // x
.Arg<Buffer_Type>()  // gamma
.Arg<Buffer_Type>()  // beta
.Arg<Buffer_Type>()  // amax
.Arg<Buffer_Type>()  // scale
.Arg<Buffer_Type>()  // scale_inv
.Ret<Buffer_Type>()  // output
.Ret<Buffer_Type>()  // mu
.Ret<Buffer_Type>()  // rsigma
.Ret<Buffer_Type>()  // amax_out
.Ret<Buffer_Type>()  // wkspace
.Ret<Buffer_Type>()  // barrier
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
......@@ -376,79 +638,6 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbeta_part_dtype, sm_margin, stream);
}
// NOTE(review): this is the pre-refactor layer-norm backward FFI handler
// that inlines all buffer unpacking; the diff removes it in favor of the
// shared LayerNormBackwardImplFFI marshaling earlier in this file.
Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf,
Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf,
Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf,
Result_Type wkspace_buf, Result_Type barrier_buf,
Result_Type dgamma_part_buf, Result_Type dbeta_part_buf,
bool zero_centered_gamma, double eps_, int64_t sm_margin_) {
auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type());
auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type());
auto wkspace_dtype = convert_ffi_datatype_to_te_dtype(wkspace_buf->element_type());
auto barrier_dtype = convert_ffi_datatype_to_te_dtype(barrier_buf->element_type());
auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype(dgamma_part_buf->element_type());
auto dbeta_part_dtype = convert_ffi_datatype_to_te_dtype(dbeta_part_buf->element_type());
auto *ograd = dz_buf.untyped_data();
auto *mu = mu_buf.untyped_data();
auto *rsigma = rsigma_buf.untyped_data();
auto *input = x_buf.untyped_data();
auto *weight = gamma_buf.untyped_data();
auto *xgrad = xgrad_buf->untyped_data();
auto *wgrad = wgrad_buf->untyped_data();
auto *dbeta = dbeta_buf->untyped_data();
auto *workspace = wkspace_buf->untyped_data();
auto *barrier = barrier_buf->untyped_data();
auto *dgamma_part = dgamma_part_buf->untyped_data();
auto *dbeta_part = dbeta_part_buf->untyped_data();
// hidden_size is the gamma length; remaining leading dims of x are batch.
auto x_size = product(x_buf.dimensions());
auto gamma_size = product(gamma_buf.dimensions());
auto hidden_size = gamma_size;
auto batch_size = x_size / gamma_size;
auto wkspace_size = product(wkspace_buf->dimensions());
auto barrier_size = product(barrier_buf->dimensions());
// Convert the partial-reduction buffer dims into TE Shape objects.
auto dgamma_part_dims = dgamma_part_buf->dimensions();
auto dbeta_part_dims = dbeta_part_buf->dimensions();
std::vector<size_t> dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end());
std::vector<size_t> dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end());
Shape dgamma_part_shape, dbeta_part_shape;
dgamma_part_shape.from_vector(dgamma_parts_dims_vector);
dbeta_part_shape.from_vector(dbeta_parts_dims_vector);
float eps = static_cast<float>(eps_);
int sm_margin = static_cast<int>(sm_margin_);
LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
dbeta_part_dtype, sm_margin, stream);
return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
FFI::Bind()
.Ctx<FFI_Stream_Type>()  // stream
.Arg<Buffer_Type>()  // dz
.Arg<Buffer_Type>()  // x
.Arg<Buffer_Type>()  // mu
.Arg<Buffer_Type>()  // rsigma
.Arg<Buffer_Type>()  // gamma
.Ret<Buffer_Type>()  // xgrad
.Ret<Buffer_Type>()  // wgrad
.Ret<Buffer_Type>()  // dbeta
.Ret<Buffer_Type>()  // wkspace
.Ret<Buffer_Type>()  // barrier
.Ret<Buffer_Type>()  // dgamma_part
.Ret<Buffer_Type>()  // dbeta_part
.Attr<bool>("zero_centered_gamma")
.Attr<double>("eps")
.Attr<int64_t>("sm_margin"),
FFI_CudaGraph_Traits);
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
auto *input = buffers[0];
auto *weight = buffers[1];
......
......@@ -52,15 +52,29 @@ pybind11::dict Registrations() {
dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward);
dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward);
// Transpose
dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler);
dict["te_cast_transpose_ffi"] = EncapsulateFFI(CastTransposeHandler);
// Activation
dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
// Normalization
dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
// These handlers are typed XLA FFI handlers (declared via
// XLA_FFI_DEFINE_HANDLER_SYMBOL), so they must be exposed as FFI capsules
// with EncapsulateFFI — EncapsulateFunction is only for legacy custom
// calls.  This matches the te_layernorm_* FFI registrations above.
dict["te_rmsnorm_forward_ffi"] = EncapsulateFFI(RMSNormForwardHandler);
dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFFI(RMSNormForwardFP8Handler);
dict["te_rmsnorm_backward_ffi"] = EncapsulateFFI(RMSNormBackwardHandler);
// Attention
dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler);
return dict;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment