Unverified commit 275902fd, authored by Kirthi Shankar Sivamani, committed by GitHub

Add margin for LayerNorm kernel SM usage (#57)

* Add margin for LayerNorm kernel SM usage
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert stylistic changes
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

parent 7f270330
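The change wires a new sm_margin argument from two environment variables down to the LayerNorm kernel launches. A minimal sketch of the intended usage, assuming the standard transformer_engine.pytorch import path (the envvar names come from this diff; the margins are read once at module construction):

import os

# Reserve SMs for other work (e.g. an overlapped communication kernel).
# Must be set before the modules are constructed, since the values are
# read in __init__ (see the getenv calls added below).
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "2"
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "4"

import torch
import transformer_engine.pytorch as te  # assumed import path

ln = te.LayerNorm(1024).cuda()
y = ln(torch.randn(16, 1024, device="cuda"))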
@@ -239,6 +239,7 @@ def layernorm_fwd_fp8(
     fp8_meta_tensor: tex.FP8TensorMeta,
     fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
     otype: tex.DType,
+    sm_margin: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """LayerNorm with FP8 output"""
     return tex.layernorm_fwd_fp8(
@@ -250,6 +251,7 @@ def layernorm_fwd_fp8(
         fp8_meta_tensor.amax_history[0][fp8_tensor],
         fp8_meta_tensor.scale_inv[fp8_tensor],
         otype,
+        sm_margin,
     )
...
@@ -373,7 +373,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                       const at::Tensor &x,
                                       const at::Tensor &mu,
                                       const at::Tensor &rsigma,
-                                      const at::Tensor &gamma
+                                      const at::Tensor &gamma,
+                                      const int sm_margin
 ) {
     auto dx = at::empty_like(x);
     auto dgamma = at::empty_like(gamma);
@@ -393,7 +394,7 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
     nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
                        dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
                        dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     // Alloc space for Tensors.
@@ -418,7 +419,7 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
     nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
                        dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
                        dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     return { dx, dgamma, dbeta };
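Each launch site now passes multiProcessorCount - sm_margin instead of the full SM count, so a margin of 0 (the default) reproduces the old behavior. A hedged Python sketch of the same arithmetic, using torch's device query with an illustrative margin value:

import torch

# Total streaming multiprocessors on the current GPU (e.g. 108 on an A100).
total_sms = torch.cuda.get_device_properties(0).multi_processor_count

sm_margin = 2  # illustrative; normally comes from the NVTE_* envvars
effective_sms = total_sms - sm_margin  # what the LN kernel is launched with
print(f"LayerNorm kernels restricted to {effective_sms}/{total_sms} SMs")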
@@ -432,7 +433,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           at::Tensor scale,
                                           at::Tensor amax,
                                           at::Tensor scale_inv,
-                                          transformer_engine::DType otype
+                                          transformer_engine::DType otype,
+                                          const int sm_margin
 ) {
     using namespace transformer_engine;
@@ -457,7 +459,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
     // This call populates workspace and barrier tensors with the required config
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     // Fill workspace and barrier
@@ -476,7 +478,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
     // Actual call to fwd kernel
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     return {ln_out, mu, rsigma};
@@ -495,7 +497,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
     // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
     // which only returns the normalized output.
     std::vector<at::Tensor> out = layernorm_fwd_fp8(
-        input, weight, bias, eps, scale, amax, scale_inv, otype);
+        input, weight, bias, eps, scale, amax, scale_inv, otype, 0);
     return out[0];
 }
@@ -503,7 +505,8 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
-                                      float eps
+                                      float eps,
+                                      const int sm_margin
 ) {
     using namespace transformer_engine;
@@ -526,7 +529,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
     // This call populates workspace and barrier tensors with the required config
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     // Fill workspace and barrier
@@ -545,7 +548,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
     // Actual call to fwd kernel
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     return {ln_out, mu, rsigma};
@@ -559,7 +562,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
 ) {
     // This is a specialized version of layernorm_fwd, optimized for inference,
     // which only returns the normalized output.
-    std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps);
+    std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0);
     return out[0];
 }
...
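Note that both inference entry points (layernorm_fwd_fp8_inf and layernorm_fwd_inf) pin the new argument to 0, so inference kernels keep using every SM; only the training paths honor the margins. A hypothetical Python analogue of layernorm_fwd_inf, with the binding module name assumed:

import torch
import transformer_engine_extensions as tex  # assumed binding module name

def layernorm_fwd_inf_py(inp: torch.Tensor, weight: torch.Tensor,
                         bias: torch.Tensor, eps: float) -> torch.Tensor:
    # Inference: hard-coded sm_margin of 0; keep only the normalized output.
    ln_out, _mu, _rsigma = tex.layernorm_fwd(inp, weight, bias, eps, 0)
    return ln_out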
@@ -81,7 +81,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                       const at::Tensor &x,
                                       const at::Tensor &mu,
                                       const at::Tensor &rsigma,
-                                      const at::Tensor &gamma
+                                      const at::Tensor &gamma,
+                                      const int sm_margin
 );
@@ -92,7 +93,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           at::Tensor scale,
                                           at::Tensor amax,
                                           at::Tensor scale_inv,
-                                          transformer_engine::DType otype
+                                          transformer_engine::DType otype,
+                                          const int sm_margin
 );
 at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
@@ -108,7 +110,8 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
-                                      float eps
+                                      float eps,
+                                      const int sm_margin
 );
 at::Tensor layernorm_fwd_inf(const at::Tensor &input,
...
@@ -597,7 +597,9 @@ class _LayerNormLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         return_layernorm_output: bool,
-        is_training: bool
+        is_training: bool,
+        fwd_ln_sm_margin: int,
+        bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -627,6 +629,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
+                    fwd_ln_sm_margin,
                 )
             else:
                 mu = rsigma = None
@@ -642,7 +645,7 @@ class _LayerNormLinear(torch.autograd.Function):
             else:
                 if is_training:
                     ln_out_return, mu, rsigma = tex.layernorm_fwd(
-                        inputmat, ln_weight, ln_bias, eps
+                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                     )
                 else:
                     ln_out_return, mu, rsigma = layernorm_fwd_inf(
@@ -657,7 +660,9 @@ class _LayerNormLinear(torch.autograd.Function):
             )
         else:
             if is_training:
-                ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
+                ln_out, mu, rsigma = tex.layernorm_fwd(
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
+                )
             else:
                 ln_out, mu, rsigma = layernorm_fwd_inf(
                     inputmat, ln_weight, ln_bias, eps
@@ -749,6 +754,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
         ctx.return_layernorm_output = return_layernorm_output
+        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -915,7 +921,7 @@ class _LayerNormLinear(torch.autograd.Function):
             d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
         dxmat, dgamma, dbeta = tex.layernorm_bwd(
-            d_ln_out, inputmat, mu, rsigma, ln_weight
+            d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
         )
         if not ctx.use_bias:
@@ -942,6 +948,8 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
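The two extra Nones exist because torch.autograd.Function.backward must return exactly one value per forward argument, and the new int margins are non-differentiable; the same pair is appended in _LayerNormMLP and _LayerNorm below. A minimal, self-contained illustration of that rule (not TE code):

import torch

class ScaleBy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: int):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One return per forward input: a gradient for x, None for `factor`.
        return grad_out * ctx.factor, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 4).sum().backward()  # x.grad == tensor([4., 4., 4.])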
@@ -1118,6 +1126,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False
+
+        # These many SMs are subtracted from the total SM count when calling forward
+        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
+        # kernels from using all SMs in the device. This is useful for cases such as
+        # communication overlap with LN.
+        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         init.ones_(self.layer_norm_weight)
@@ -1189,6 +1204,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.parallel_mode,
             self.return_layernorm_output,
             self.training,
+            self.fwd_ln_sm_margin,
+            self.bwd_ln_sm_margin,
         )
         out = fwd_fn(*args)
@@ -1779,7 +1796,9 @@ class _LayerNormMLP(torch.autograd.Function):
         return_layernorm_output: bool,
         bias_gelu_nvfusion: bool,
         set_parallel_mode: bool,
-        is_training: bool
+        is_training: bool,
+        fwd_ln_sm_margin: int,
+        bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -1808,6 +1827,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
+                    fwd_ln_sm_margin,
                 )
             else:
                 ln_out = layernorm_fwd_fp8_inf(
@@ -1821,7 +1841,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 )
             else:
                 ln_out_return, mu, rsigma = tex.layernorm_fwd(
-                    inputmat, ln_weight, ln_bias, eps
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
                 ln_out = cast_to_fp8(
                     ln_out_return,
@@ -1831,7 +1851,9 @@ class _LayerNormMLP(torch.autograd.Function):
             )
         else:
             if is_training:
-                ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
+                ln_out, mu, rsigma = tex.layernorm_fwd(
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
+                )
             else:
                 ln_out, mu, rsigma = layernorm_fwd_inf(
                     inputmat, ln_weight, ln_bias, eps
@@ -1988,6 +2010,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
         ctx.return_layernorm_output = return_layernorm_output
         ctx.set_parallel_mode = set_parallel_mode
+        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         # Row Parallel Linear
         if set_parallel_mode and sequence_parallel:
@@ -2281,7 +2304,7 @@ class _LayerNormMLP(torch.autograd.Function):
             d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
         dxmat, dgamma, dbeta = tex.layernorm_bwd(
-            d_ln_out, inputmat, mu, rsigma, ln_weight
+            d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
         )
         if not ctx.use_bias:
@@ -2313,6 +2336,8 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
@@ -2527,6 +2552,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.size_per_partition, seq_length, micro_batch_size
         )
+
+        # These many SMs are subtracted from the total SM count when calling forward
+        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
+        # kernels from using all SMs in the device. This is useful for cases such as
+        # communication overlap with LN.
+        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         init.ones_(self.layer_norm_weight)
@@ -2590,6 +2622,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.bias_gelu_nvfusion,
             self.set_parallel_mode,
             self.training,
+            self.fwd_ln_sm_margin,
+            self.bwd_ln_sm_margin,
         )
         out = fwd_fn(*args)
@@ -2618,6 +2652,8 @@ class _LayerNorm(torch.autograd.Function):
         ln_weight: torch.Tensor,
         ln_bias: torch.Tensor,
         eps: float,
+        fwd_ln_sm_margin: int,
+        bwd_ln_sm_margin: int,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -2625,9 +2661,10 @@ class _LayerNorm(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "LayerNorm not possible"
         inputmat = inp.view((-1, in_features))
-        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
+        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin)
         ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
         ctx.inp_shape = inp.shape
+        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         return ln_out.view_as(inp)

     @staticmethod
@@ -2638,9 +2675,9 @@ class _LayerNorm(torch.autograd.Function):
         grad_output = grad_output.contiguous()
         d_ln_out = grad_output.view(inputmat.shape)
         dxmat, dgamma, dbeta = tex.layernorm_bwd(
-            d_ln_out, inputmat, mu, rsigma, ln_weight
+            d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
         )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None
+        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None
 class LayerNorm(torch.nn.Module):
@@ -2695,6 +2732,13 @@ class LayerNorm(torch.nn.Module):
             setattr(self.bias, "sequence_parallel", sequence_parallel)
         self.reset_layer_norm_parameters()
+
+        # These many SMs are subtracted from the total SM count when calling forward
+        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
+        # kernels from using all SMs in the device. This is useful for cases such as
+        # communication overlap with LN.
+        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

     def load_state_dict(
         self,
         state_dict: Mapping[str, Any],
@@ -2725,4 +2769,11 @@ class LayerNorm(torch.nn.Module):
         if hasattr(self, "layer_norm_bias"):
             setattr(self, "bias", self.layer_norm_bias)
-        return _LayerNorm.apply(inp, self.weight, self.bias, self.eps)
+        return _LayerNorm.apply(
+            inp,
+            self.weight,
+            self.bias,
+            self.eps,
+            self.fwd_ln_sm_margin,
+            self.bwd_ln_sm_margin,
+        )
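Putting it together: with the margins captured in __init__ above, a training step launches the forward LN kernel with total_sms - fwd_ln_sm_margin SMs and the backward kernel with total_sms - bwd_ln_sm_margin. A hedged end-to-end sketch, assuming a CUDA device and the transformer_engine.pytorch import path:

import torch
import transformer_engine.pytorch as te  # assumed import path

ln = te.LayerNorm(1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)
y = ln(x)           # forward pass uses self.fwd_ln_sm_margin
y.sum().backward()  # backward pass uses ctx.bwd_ln_sm_margin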