Unverified commit 753eed31, authored by zlsh80826 and committed by GitHub
Browse files

[JAX] Support layernorm/rmsnorm sm_margin control through environment variable (#520)



Support layernorm sm_margin through environment variables
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent 0fc402fb
...@@ -8,6 +8,7 @@ from dataclasses import dataclass ...@@ -8,6 +8,7 @@ from dataclasses import dataclass
from typing import Tuple from typing import Tuple
from functools import partial, reduce from functools import partial, reduce
import operator import operator
import os
import warnings import warnings
import numpy as np import numpy as np
...@@ -339,6 +340,8 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -339,6 +340,8 @@ class LayerNormFwdPrimitive(BasePrimitive):
operand_shapes = [x_shape, g_shape, b_shape] operand_shapes = [x_shape, g_shape, b_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -346,6 +349,7 @@ class LayerNormFwdPrimitive(BasePrimitive): ...@@ -346,6 +349,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin,
) )
out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False) out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)
...@@ -491,6 +495,8 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -491,6 +495,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape] operand_shapes = [dz_shape, mu_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -498,6 +504,7 @@ class LayerNormBwdPrimitive(BasePrimitive): ...@@ -498,6 +504,7 @@ class LayerNormBwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin,
) )
out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False) out = custom_caller(LayerNormBwdPrimitive.name, args, opaque, False)
...@@ -642,6 +649,8 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -642,6 +649,8 @@ class RmsNormFwdPrimitive(BasePrimitive):
operand_shapes = [x_shape, g_shape] operand_shapes = [x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -649,6 +658,7 @@ class RmsNormFwdPrimitive(BasePrimitive): ...@@ -649,6 +658,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin,
) )
out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False) out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)
...@@ -778,6 +788,8 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -778,6 +788,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -785,6 +797,7 @@ class RmsNormBwdPrimitive(BasePrimitive): ...@@ -785,6 +797,7 @@ class RmsNormBwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin,
) )
out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False) out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)
...@@ -3040,6 +3053,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3040,6 +3053,8 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
] ]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -3047,6 +3062,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive): ...@@ -3047,6 +3062,7 @@ class LayerNormFwdFp8Primitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
sm_margin,
) )
out = custom_caller(LayerNormFwdFp8Primitive.name, out = custom_caller(LayerNormFwdFp8Primitive.name,
...@@ -3244,6 +3260,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3244,6 +3260,8 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes) args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
opaque = transformer_engine_jax.pack_norm_descriptor( opaque = transformer_engine_jax.pack_norm_descriptor(
batch_size, batch_size,
hidden_size, hidden_size,
...@@ -3251,6 +3269,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive): ...@@ -3251,6 +3269,7 @@ class RmsNormFwdFp8Primitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype),
False, # RMSNorm doesn't support zero_centered_gamma False, # RMSNorm doesn't support zero_centered_gamma
epsilon, epsilon,
sm_margin,
) )
out = custom_caller(RmsNormFwdFp8Primitive.name, out = custom_caller(RmsNormFwdFp8Primitive.name,
......
...@@ -69,9 +69,9 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType ...@@ -69,9 +69,9 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, DType
} }
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype, pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
bool zero_centered_gamma, float eps) { bool zero_centered_gamma, float eps, int sm_margin) {
return PackOpaque( return PackOpaque(
CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps}); CustomCallNormDescriptor{n, hidden, x_dtype, w_dtype, zero_centered_gamma, eps, sm_margin});
} }
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads, pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch, size_t pad_batch, size_t heads,
...@@ -282,10 +282,10 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque ...@@ -282,10 +282,10 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
desc.use_split_accumulator, 0, stream); desc.use_split_accumulator, 0, stream);
} }
void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void *input, void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
DType in_dtype, void *weight, DType w_dtype, void *bias, void *output, int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
DType out_dtype, void *mu, void *rsigma, float *amax, float *scale, void *bias, void *output, DType out_dtype, void *mu, void *rsigma,
float *scale_inv, cudaStream_t stream) { float *amax, float *scale, float *scale_inv, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden}; auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden}; auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n}; auto intermediates_shape = std::vector<size_t>{n};
...@@ -302,7 +302,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo ...@@ -302,7 +302,7 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
// Create uninitialized workspace, barrier and init them on the first // Create uninitialized workspace, barrier and init them on the first
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor; TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
if (!is_layer_norm) { if (!is_layer_norm) {
...@@ -351,9 +351,9 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo ...@@ -351,9 +351,9 @@ void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, flo
} }
void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps,
void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, int sm_margin, void *input, DType in_dtype, void *weight, DType w_dtype,
void *mu, void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *ograd, void *mu, void *rsigma, void *xgrad, void *wgrad,
cudaStream_t stream) { void *dbeta, cudaStream_t stream) {
auto input_shape = std::vector<size_t>{n, hidden}; auto input_shape = std::vector<size_t>{n, hidden};
auto weight_shape = std::vector<size_t>{hidden}; auto weight_shape = std::vector<size_t>{hidden};
auto intermediates_shape = std::vector<size_t>{n}; auto intermediates_shape = std::vector<size_t>{n};
...@@ -376,7 +376,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl ...@@ -376,7 +376,7 @@ void LayerNormBackwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, fl
TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor; TensorWrapper dummy_workspace_tensor, dummy_barrier_tensor;
TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor;
auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin;
size_t dbeta_part_size{}; size_t dbeta_part_size{};
auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd;
...@@ -465,11 +465,13 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque ...@@ -465,11 +465,13 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream); w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
} }
void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -492,9 +494,11 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s ...@@ -492,9 +494,11 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s
auto eps = desc.eps; auto eps = desc.eps;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream); w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
} }
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -506,6 +510,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -506,6 +510,7 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto *ograd = buffers[0]; auto *ograd = buffers[0];
auto *mu = buffers[1]; auto *mu = buffers[1];
...@@ -516,8 +521,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -516,8 +521,8 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto *wgrad = buffers[6]; auto *wgrad = buffers[6];
auto *dbeta = buffers[7]; auto *dbeta = buffers[7];
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
ograd, mu, rsigma, xgrad, wgrad, dbeta, stream); w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
} }
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -541,10 +546,12 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, ...@@ -541,10 +546,12 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = DType::kFloat8E4M3; auto out_dtype = DType::kFloat8E4M3;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream); w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
} }
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -566,10 +573,12 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz ...@@ -566,10 +573,12 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
auto out_dtype = in_dtype; auto out_dtype = in_dtype;
LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, LayerNormForwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv, stream); w_dtype, bias, output, out_dtype, mu, rsigma, amax, scale, scale_inv,
stream);
} }
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
...@@ -587,12 +596,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si ...@@ -587,12 +596,13 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si
auto w_dtype = desc.w_dtype; auto w_dtype = desc.w_dtype;
auto eps = desc.eps; auto eps = desc.eps;
auto zero_centered_gamma = desc.zero_centered_gamma; auto zero_centered_gamma = desc.zero_centered_gamma;
auto sm_margin = desc.sm_margin;
void *mu = nullptr; void *mu = nullptr;
void *dbeta = nullptr; void *dbeta = nullptr;
LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, input, in_dtype, weight, w_dtype, LayerNormBackwardImpl(n, hidden, zero_centered_gamma, eps, sm_margin, input, in_dtype, weight,
ograd, mu, rsigma, xgrad, wgrad, dbeta, stream); w_dtype, ograd, mu, rsigma, xgrad, wgrad, dbeta, stream);
} }
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "common/util/logging.h"
#include <transformer_engine/fused_attn.h> #include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
#include "common/util/logging.h"
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -75,10 +75,11 @@ struct CustomCallNormDescriptor { ...@@ -75,10 +75,11 @@ struct CustomCallNormDescriptor {
DType w_dtype; DType w_dtype;
bool zero_centered_gamma; bool zero_centered_gamma;
float eps; float eps;
int sm_margin;
}; };
pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype, pybind11::bytes PackCustomCallNormDescriptor(size_t n, size_t hidden, DType x_dtype, DType w_dtype,
bool zero_centered_gamma, float eps); bool zero_centered_gamma, float eps, int sm_margin);
struct SoftmaxDescriptor { struct SoftmaxDescriptor {
size_t batch; size_t batch;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment