Unverified commit ba0fe9a7 authored by Reese Wang, committed by GitHub

[JAX] Propagate sm_margin to the underlying layernorm kernels (#1089)



* Propagate sm_margin to the underlying layernorm kernels

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
parent 0075a46a
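
Summary of the change: the lowering rules already read NVTE_FWD_LAYERNORM_SM_MARGIN / NVTE_BWD_LAYERNORM_SM_MARGIN when packing the kernel descriptor, but the workspace-size queries ignored the margin. The diff below adds cached getters on the Python side and threads the margin through both the workspace query and the C++ kernel implementations, so both paths agree on the SM count. A minimal sketch of the cached-getter behavior (the getter mirrors the one added below; the values are illustrative):

import os
from functools import cache

@cache
def get_forward_sm_margin():
    """Read NVTE_FWD_LAYERNORM_SM_MARGIN once; later calls reuse the cached value."""
    return int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

# The margin must be in the environment before the first call:
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "16"
print(get_forward_sm_margin())  # 16
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "32"
print(get_forward_sm_margin())  # still 16 -- functools.cache memoized the first read
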
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
 """JAX/TE custom ops for normalization"""
-from functools import partial, reduce
+from functools import partial, reduce, cache
 import operator
 import os
 import warnings
@@ -40,6 +40,18 @@ __all__ = [
 ]
 
 
+@cache
+def get_forward_sm_margin():
+    """Retrieves the number of streaming multiprocessors (SMs) reserved for other kernels"""
+    return int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+
+
+@cache
+def get_backward_sm_margin():
+    """Retrieves the number of streaming multiprocessors (SMs) reserved for other kernels"""
+    return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+
+
 class LayerNormFwdPrimitive(BasePrimitive):
     """
     Layer Normalization Forward Primitive
@@ -77,6 +89,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
             True,
             kwargs["zero_centered_gamma"],
             kwargs["epsilon"],
+            get_forward_sm_margin(),
         )
         wkspace_aval = out_aval.update(
             shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
@@ -136,7 +149,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
         operand_shapes = [x_shape, g_shape, b_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        sm_margin = get_forward_sm_margin()
 
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
@@ -354,6 +367,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
                 True,
                 kwargs["zero_centered_gamma"],
                 kwargs["epsilon"],
+                get_backward_sm_margin(),
             )
         )
         wkspace_aval = dx_aval.update(
@@ -420,7 +434,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
         operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        sm_margin = get_backward_sm_margin()
 
         wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:]
         opaque = transformer_engine_jax.pack_norm_descriptor(
@@ -591,6 +605,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
             False,
             False,
             kwargs["epsilon"],
+            get_forward_sm_margin(),
         )
         wkspace_aval = out_aval.update(
             shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
@@ -638,7 +653,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
         operand_shapes = [x_shape, g_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        sm_margin = get_forward_sm_margin()
 
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
@@ -776,6 +791,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
                 False,
                 False,
                 kwargs["epsilon"],
+                get_backward_sm_margin(),
             )
         )
         wkspace_aval = dx_aval.update(
@@ -829,7 +845,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
         operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        sm_margin = get_backward_sm_margin()
 
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
@@ -989,6 +1005,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
             True,
             zero_centered_gamma,
             epsilon,
+            get_forward_sm_margin(),
         )
 
         out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
@@ -1076,7 +1093,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
         ]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        sm_margin = get_forward_sm_margin()
 
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
@@ -1296,6 +1313,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
             False,
             False,
             epsilon,
+            get_forward_sm_margin(),
         )
 
         out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
@@ -1365,7 +1383,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
         operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
 
-        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        sm_margin = get_forward_sm_margin()
 
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
...
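
On the C++ side (the next two files), the margin travels inside the opaque norm descriptor and is subtracted from the device's multiprocessor count before the nvte_* kernels are configured. A rough Python model of that flow, with hypothetical names (pack_norm_descriptor here is a plain dict, and layer_norm_forward stands in for LayerNormForwardImpl; the SM counts are illustrative):

def pack_norm_descriptor(batch_size, hidden_size, sm_margin):
    # Models transformer_engine_jax.pack_norm_descriptor: the margin is
    # serialized alongside the shape/dtype fields and decoded in C++.
    return {"batch_size": batch_size, "hidden_size": hidden_size, "sm_margin": sm_margin}

def layer_norm_forward(desc, device_sm_count=132):
    # Hypothetical stand-in for LayerNormForwardImpl: the kernel is offered
    # the device SM count minus the reserved margin (132 SMs ~ an H100).
    num_sm = device_sm_count - desc["sm_margin"]
    return num_sm

desc = pack_norm_descriptor(batch_size=8, hidden_size=4096, sm_margin=16)
assert layer_norm_forward(desc) == 116  # 16 SMs left free for concurrent kernels
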
@@ -186,7 +186,7 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o
 pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype, DType out_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
-                                                  float eps);
+                                                  float eps, int sm_margin);
 
 void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
 
@@ -196,7 +196,7 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
 pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType w_dtype,
                                                    bool is_layer_norm, bool zero_centered_gamma,
-                                                   float eps);
+                                                   float eps, int sm_margin);
 
 void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
...
@@ -13,7 +13,7 @@ namespace jax {
 pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype, DType out_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
-                                                  float eps) {
+                                                  float eps, int sm_margin) {
   auto input_shape = std::vector<size_t>{batch_size, hidden_size};
   auto weight_shape = std::vector<size_t>{hidden_size};
   auto intermediates_shape = std::vector<size_t>{batch_size};
@@ -26,7 +26,7 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
   // dummy tensor wrappers that will carry workspace size info later
   TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
 
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
+  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
   auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
 
   if (is_layer_norm) {
     auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
@@ -53,7 +53,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
                           DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
                           DType out_dtype, void *workspace, DType work_dtype, void *barrier,
                           DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale,
-                          float *scale_inv, cudaStream_t stream) {
+                          float *scale_inv, int sm_margin, cudaStream_t stream) {
   auto input_shape = std::vector<size_t>{batch_size, hidden_size};
   auto weight_shape = std::vector<size_t>{hidden_size};
   auto intermediates_shape = std::vector<size_t>{batch_size};
@@ -70,7 +70,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
   auto output_tensor = TensorWrapper(output, input_shape, out_dtype, amax, scale, scale_inv);
   auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);
 
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
+  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
   auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
 
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype);
@@ -94,7 +94,7 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac
 pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType w_dtype,
                                                    bool is_layer_norm, bool zero_centered_gamma,
-                                                   float eps) {
+                                                   float eps, int sm_margin) {
   auto input_shape = std::vector<size_t>{batch_size, hidden_size};
   auto weight_shape = std::vector<size_t>{hidden_size};
   auto intermediates_shape = std::vector<size_t>{batch_size};
@@ -111,7 +111,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
   // dummy tensor wrappers that will carry workspace size info later
   TensorWrapper dummy_work_tensor, dummy_barrier_tensor;
   TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
 
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
+  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
   auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
 
   // initialize dBeta information here -- layernorm will modify but RMSnorm will not
@@ -151,7 +151,7 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
                            void *weight, DType w_dtype, void *ograd, void *workspace,
                            DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu,
                            void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part,
-                           DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype,
+                           DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin,
                            cudaStream_t stream) {
   auto input_shape = std::vector<size_t>{batch_size, hidden_size};
   auto weight_shape = std::vector<size_t>{hidden_size};
@@ -173,7 +173,7 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
   auto xgrad_tensor = TensorWrapper(xgrad, input_shape, x_dtype);
   auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype);
 
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
+  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
   auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
 
   auto workspace_shape = std::vector<size_t>{wkspace_size};
@@ -227,13 +227,14 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
   auto barrier_dtype = desc.barrier_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
   auto out_dtype = DType::kFloat8E4M3;
 
   LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                        eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                        wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
+                       sm_margin, stream);
 }
 
 void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -262,11 +263,12 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
   auto eps = desc.eps;
   auto out_dtype = in_dtype;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
   LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                        eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                        wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
+                       sm_margin, stream);
 }
 
 void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -286,6 +288,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto dbeta_part_dtype = desc.dbeta_part_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
   auto *ograd = buffers[0];
   auto *mu = buffers[1];
@@ -304,7 +307,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
                         dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                         w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                         rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
-                        dbeta_part_dtype, stream);
+                        dbeta_part_dtype, sm_margin, stream);
 }
 
 void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -334,12 +337,13 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
   auto barrier_dtype = desc.barrier_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
   auto out_dtype = DType::kFloat8E4M3;
 
   LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                        eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                        wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
+                       sm_margin, stream);
 }
 
 void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -367,12 +371,13 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
   auto barrier_dtype = desc.barrier_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
   auto out_dtype = in_dtype;
 
   LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma,
                        eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace,
                        wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv,
-                       stream);
+                       sm_margin, stream);
 }
 
 void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -406,12 +411,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
   auto dbeta_part_dtype = DType::kByte;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
   LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape,
                         dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight,
                         w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu,
                         rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part,
-                        dbeta_part_dtype, stream);
+                        dbeta_part_dtype, sm_margin, stream);
 }
 
 }  // namespace jax
...
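
A hedged usage note: because the new getters are wrapped in functools.cache, both margins should be set in the environment before the process first touches the layernorm primitives. The snippet below sets them from Python before importing the library (the margin values are illustrative):

import os

# Reserve SMs for, e.g., overlapped communication kernels; values are illustrative.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "16"
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "16"

import transformer_engine.jax  # import after setting the variables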