Unverified commit 753eed31, authored by zlsh80826, committed by GitHub

[JAX] Support layernorm/rmsnorm sm_margin control through environment variable (#520)



Support layernorm sm_margin through environment variables
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 0fc402fb
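
The new knob is driven entirely by two environment variables. A minimal usage sketch follows (the surrounding model code is hypothetical; only the two `NVTE_*_LAYERNORM_SM_MARGIN` variables come from this patch). Because the margin is read with `os.getenv()` inside each primitive's lowering, it must be set before the norm ops are first traced and compiled:

```python
import os

# SMs to leave idle for the forward and backward norm kernels, respectively.
# Both default to "0" (use every SM). The values here are illustrative.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "4"
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "8"

# ...build and run the Transformer Engine JAX model as usual. The
# layernorm/rmsnorm custom calls will then launch on
# (device SM count - margin) SMs, leaving the margin free for
# concurrent work such as communication kernels.
```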
@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from typing import Tuple
 from functools import partial, reduce
 import operator
+import os
 import warnings
 import numpy as np
@@ -339,6 +340,8 @@ class LayerNormFwdPrimitive(BasePrimitive):
         operand_shapes = [x_shape, g_shape, b_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -346,6 +349,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             zero_centered_gamma,
             epsilon,
+            sm_margin,
         )
         out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)
@@ -491,6 +495,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
         operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -498,6 +504,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             zero_centered_gamma,
             epsilon,
+            sm_margin,
         )
         out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)
@@ -642,6 +649,8 @@ class RmsNormFwdPrimitive(BasePrimitive):
         operand_shapes = [x_shape, g_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -649,6 +658,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             False,  # RMSNorm doesn't support zero_centered_gamma
             epsilon,
+            sm_margin,
         )
         out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)
@@ -778,6 +788,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
         operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -785,6 +797,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             False,  # RMSNorm doesn't support zero_centered_gamma
             epsilon,
+            sm_margin,
         )
         out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)
@@ -3040,6 +3053,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
         ]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -3047,6 +3062,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             zero_centered_gamma,
             epsilon,
+            sm_margin,
         )
         out = custom_caller(LayerNormFwdFp8Primitive.name,
@@ -3244,6 +3260,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
         operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
         args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
+        sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -3251,6 +3269,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
             jax_dtype_to_te_dtype(gamma_aval.dtype),
             False,  # RMSNorm doesn't support zero_centered_gamma
             epsilon,
+            sm_margin,
         )
         out = custom_caller(RmsNormFwdFp8Primitive.name,
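All six primitives above follow the same lowering-time pattern. A consolidated, standalone sketch is below; the shape, dtype, and epsilon values are illustrative placeholders, and exposing the `DType` enum on the `transformer_engine_jax` module is an assumption. `pack_norm_descriptor` itself is the pybind11 binding extended in the C++ hunks that follow:

```python
import os

import transformer_engine_jax
from transformer_engine_jax import DType  # assumed export of the TE dtype enum

# Placeholders standing in for values computed inside lowering().
batch_size, hidden_size = 128, 1024

# Forward primitives read the FWD variable; backward primitives read BWD.
sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))

opaque = transformer_engine_jax.pack_norm_descriptor(
    batch_size,       # n: number of rows to normalize
    hidden_size,      # hidden: row width
    DType.kBFloat16,  # x_dtype (illustrative)
    DType.kFloat32,   # w_dtype (illustrative)
    False,            # zero_centered_gamma (RMSNorm always passes False)
    1e-6,             # epsilon (illustrative)
    sm_margin,        # new: SMs the C++ side subtracts from the launch budget
)
```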
@@ -69,9 +69,9 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType
 }
 pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
-                                             bool zero_centered_gamma, float eps) {
+                                             bool zero_centered_gamma, float eps, int sm_margin) {
   return PackOpaque(
-      CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps});
+      CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps, sm_margin});
 }
 pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
@@ -282,10 +282,10 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
             desc.use_split_accumulator, 0, stream);
 }
-void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void *input,
-                          DType in_dtype, void *weight, DType w_dtype, void *bias, void *output,
-                          DType out_dtype, void *mu, void *rsigma, float *amax, float *scale,
-                          float *scale_inv, cudaStream_t stream) {
+void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
+                          int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
+                          void *bias, void *output, DType out_dtype, void *mu, void *rsigma,
+                          float *amax, float *scale, float *scale_inv, cudaStream_t stream) {
   auto input_shape = std::vector<size_t>{n, hidden};
   auto weight_shape = std::vector<size_t>{hidden};
   auto intermediates_shape = std::vector<size_t>{n};
@@ -302,7 +302,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
   // Create uninitialized workspace, barrier and init them on the first
   TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
+  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
   auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
   if (!is_layer_norm) {
@@ -351,9 +351,9 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
 }
 void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
-                           void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd,
-                           void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta,
-                           cudaStream_t stream) {
+                           int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
+                           void *ograd, void *mu, void *rsigma, void *xgrad, void *wgrad,
+                           void *dbeta, cudaStream_t stream) {
   auto input_shape = std::vector<size_t>{n, hidden};
   auto weight_shape = std::vector<size_t>{hidden};
   auto intermediates_shape = std::vector<size_t>{n};
@@ -376,7 +376,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
   TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
   TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
-  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount();
+  auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
   size_t dbeta_part_size{};
   auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
@@ -465,11 +465,13 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
   auto w_dtype = desc.w_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
   auto out_dtype = DType::kFloat8E4M3;
-  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
-                       bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
+  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
+                       w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
+                       stream);
 }
 void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -492,9 +494,11 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
   auto eps = desc.eps;
   auto out_dtype = in_dtype;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
 
-  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
-                       bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
+  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
+                       w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
+                       stream);
 }
 void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -506,6 +510,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto w_dtype = desc.w_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
   auto *ograd = buffers[0];
   auto *mu = buffers[1];
@@ -516,8 +521,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
   auto *wgrad = buffers[6];
   auto *dbeta = buffers[7];
-  LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
-                        ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
+  LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
+                        w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
 }
 void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -541,10 +546,12 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
   auto w_dtype = desc.w_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
   auto out_dtype = DType::kFloat8E4M3;
-  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
-                       bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
+  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
+                       w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
+                       stream);
 }
 void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -566,10 +573,12 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
   auto w_dtype = desc.w_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
   auto out_dtype = in_dtype;
-  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
-                       bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream);
+  LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
+                       w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
+                       stream);
 }
 void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
@@ -587,12 +596,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
   auto w_dtype = desc.w_dtype;
   auto eps = desc.eps;
   auto zero_centered_gamma = desc.zero_centered_gamma;
+  auto sm_margin = desc.sm_margin;
   void *mu = nullptr;
   void *dbeta = nullptr;
-  LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype,
-                        ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
+  LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
+                        w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
 }
 void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
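On the C++ side the margin is applied as a plain subtraction from the device's multiprocessor count before the workspace query and kernel launch; the patch does not clamp the result. A small sketch of that arithmetic (the helper name is hypothetical):

```python
def effective_sm_count(device_sm_count: int, sm_margin: int) -> int:
    """Mirrors `GetMultiProcessorCount() - sm_margin` above.

    No clamping is applied in the patch, so a margin at or above the device's
    SM count would leave a non-positive budget; keep margins small.
    """
    return device_sm_count - sm_margin

# Example: an A100-SXM reports 108 SMs, so NVTE_FWD_LAYERNORM_SM_MARGIN=8
# plans the forward norm kernels on 108 - 8 = 100 SMs.
assert effective_sm_count(108, 8) == 100
```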
@@ -16,9 +16,9 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-#include "common/util/logging.h"
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>
+#include "common/util/logging.h"
 namespace transformer_engine {
 namespace jax {
@@ -75,10 +75,11 @@ struct CustomCallNormDescriptor {
   DType w_dtype;
   bool zero_centered_gamma;
   float eps;
+  int sm_margin;
 };
 pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
-                                             bool zero_centered_gamma, float eps);
+                                             bool zero_centered_gamma, float eps, int sm_margin);
 struct SoftmaxDescriptor {
   size_t batch;