"...AutoBuildImmortalWrt.git" did not exist on "3e05d02b49abf9c8ebd840eaaed7e18a642360f3"
Unverified Commit 5d34b2ac authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Add SM margin to LayerNorm in inference (#772)



* Add LN margin to inference
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* Fix symbolic func registration
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix grads
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b4ef463d
...@@ -660,6 +660,7 @@ def test_export_layernorm( ...@@ -660,6 +660,7 @@ def test_export_layernorm(
self.meta, self.meta,
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
0,
zero_centered_gamma) zero_centered_gamma)
ret = cast_from_fp8( ret = cast_from_fp8(
...@@ -748,6 +749,7 @@ def test_export_rmsnorm( ...@@ -748,6 +749,7 @@ def test_export_rmsnorm(
self.meta, self.meta,
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
0,
zero_centered_gamma) zero_centered_gamma)
ret = cast_from_fp8( ret = cast_from_fp8(
...@@ -1279,6 +1281,7 @@ def test_export_gemm_layernorm( ...@@ -1279,6 +1281,7 @@ def test_export_gemm_layernorm(
self.meta, self.meta,
self.fp8_tensor, self.fp8_tensor,
self.fp8_type, self.fp8_type,
0,
zero_centered_gamma) zero_centered_gamma)
x = cast_from_fp8( x = cast_from_fp8(
......
...@@ -565,6 +565,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -565,6 +565,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
# communication overlap with LN. # communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
def _te_forward( def _te_forward(
self, self,
...@@ -600,7 +601,7 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -600,7 +601,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
self.activation_dtype, self.activation_dtype,
self.return_layernorm_output, self.return_layernorm_output,
paddle.is_grad_enabled(), paddle.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization, self.normalization,
......
...@@ -824,6 +824,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -824,6 +824,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
# communication overlap with LN. # communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
def _te_forward( def _te_forward(
self, self,
...@@ -865,7 +866,7 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -865,7 +866,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
self.activation_dtype, self.activation_dtype,
self.return_layernorm_output, self.return_layernorm_output,
paddle.is_grad_enabled(), paddle.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization, self.normalization,
......
...@@ -66,6 +66,7 @@ def layernorm_fwd_fp8_inf( ...@@ -66,6 +66,7 @@ def layernorm_fwd_fp8_inf(
fp8_meta_tensor: tex.FP8TensorMeta, fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType, otype: tex.DType,
sm_margin: int,
zero_centered_gamma, zero_centered_gamma,
) -> torch.Tensor: ) -> torch.Tensor:
"""LayerNorm with FP8 output. """LayerNorm with FP8 output.
...@@ -83,6 +84,7 @@ def layernorm_fwd_fp8_inf( ...@@ -83,6 +84,7 @@ def layernorm_fwd_fp8_inf(
fp8_meta_tensor.scale_inv, fp8_meta_tensor.scale_inv,
fp8_tensor, fp8_tensor,
otype, otype,
sm_margin,
zero_centered_gamma) zero_centered_gamma)
return ret return ret
...@@ -92,6 +94,7 @@ def layernorm_fwd_inf( ...@@ -92,6 +94,7 @@ def layernorm_fwd_inf(
weight: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor, bias: torch.Tensor,
eps: float, eps: float,
sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
) -> torch.Tensor: ) -> torch.Tensor:
"""LayerNorm with FP8 output""" """LayerNorm with FP8 output"""
...@@ -100,6 +103,7 @@ def layernorm_fwd_inf( ...@@ -100,6 +103,7 @@ def layernorm_fwd_inf(
weight, weight,
bias, bias,
eps, eps,
sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
...@@ -149,6 +153,7 @@ def rmsnorm_fwd_fp8_inf( ...@@ -149,6 +153,7 @@ def rmsnorm_fwd_fp8_inf(
fp8_meta_tensor: tex.FP8TensorMeta, fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType, otype: tex.DType,
sm_margin: int,
zero_centered_gamma, zero_centered_gamma,
) -> torch.Tensor: ) -> torch.Tensor:
"""RMSNorm with FP8 output. """RMSNorm with FP8 output.
...@@ -165,6 +170,7 @@ def rmsnorm_fwd_fp8_inf( ...@@ -165,6 +170,7 @@ def rmsnorm_fwd_fp8_inf(
fp8_meta_tensor.scale_inv, fp8_meta_tensor.scale_inv,
fp8_tensor, fp8_tensor,
otype, otype,
sm_margin,
zero_centered_gamma) zero_centered_gamma)
return ret return ret
...@@ -173,6 +179,7 @@ def rmsnorm_fwd_inf( ...@@ -173,6 +179,7 @@ def rmsnorm_fwd_inf(
inp: torch.Tensor, inp: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
eps: float, eps: float,
sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
) -> torch.Tensor: ) -> torch.Tensor:
"""RMSNorm with FP8 output""" """RMSNorm with FP8 output"""
...@@ -180,5 +187,6 @@ def rmsnorm_fwd_inf( ...@@ -180,5 +187,6 @@ def rmsnorm_fwd_inf(
inp, inp,
weight, weight,
eps, eps,
sm_margin,
zero_centered_gamma, zero_centered_gamma,
) )
...@@ -408,6 +408,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -408,6 +408,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
); );
...@@ -432,6 +433,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, ...@@ -432,6 +433,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
float eps, float eps,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
); );
...@@ -478,6 +480,7 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -478,6 +480,7 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
); );
...@@ -499,6 +502,7 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, ...@@ -499,6 +502,7 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
float eps, float eps,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
); );
......
...@@ -154,12 +154,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -154,12 +154,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
) { ) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference, // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output. // which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8( std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
return out[0]; return out[0];
} }
...@@ -203,11 +204,13 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, ...@@ -203,11 +204,13 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
float eps, float eps,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
) { ) {
// This is a specialized version of layernorm_fwd, optimized for inference, // This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output. // which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma); std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, sm_margin,
zero_centered_gamma);
return out[0]; return out[0];
} }
...@@ -345,12 +348,13 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -345,12 +348,13 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
) { ) {
// This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
// which only returns the normalized output. // which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd_fp8( std::vector<at::Tensor> out = rmsnorm_fwd_fp8(
input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); input, weight, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
return out[0]; return out[0];
} }
...@@ -391,10 +395,11 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, ...@@ -391,10 +395,11 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
float eps, float eps,
const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma
) { ) {
// This is a specialized version of rmsnorm_fwd, optimized for inference, // This is a specialized version of rmsnorm_fwd, optimized for inference,
// which only returns the normalized output. // which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma); std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma);
return out[0]; return out[0];
} }
...@@ -365,6 +365,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -365,6 +365,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
int64_t otype, int64_t otype,
const int8_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype); transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
...@@ -377,6 +378,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -377,6 +378,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
amax, amax,
scale_inv, scale_inv,
otype_arg, otype_arg,
sm_margin,
zero_centered_gamma); zero_centered_gamma);
return output; return output;
...@@ -387,6 +389,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -387,6 +389,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &bias,
double eps, double eps,
const int8_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
...@@ -394,6 +397,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -394,6 +397,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
weight, weight,
bias, bias,
eps_float, eps_float,
sm_margin,
zero_centered_gamma); zero_centered_gamma);
return output; return output;
...@@ -408,6 +412,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -408,6 +412,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
int64_t fp8_tensor, int64_t fp8_tensor,
int64_t otype, int64_t otype,
const int8_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype); transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
...@@ -419,6 +424,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -419,6 +424,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
amax, amax,
scale_inv, scale_inv,
otype_arg, otype_arg,
sm_margin,
zero_centered_gamma); zero_centered_gamma);
return output; return output;
...@@ -428,12 +434,14 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -428,12 +434,14 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight, const at::Tensor &weight,
double eps, double eps,
const int8_t sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps); float eps_float = static_cast<float>(eps);
at::Tensor output = rmsnorm_fwd_inf(input, at::Tensor output = rmsnorm_fwd_inf(input,
weight, weight,
eps_float, eps_float,
sm_margin,
zero_centered_gamma); zero_centered_gamma);
return output; return output;
......
...@@ -78,6 +78,7 @@ def _apply_normalization(inputmat:torch.Tensor, ...@@ -78,6 +78,7 @@ def _apply_normalization(inputmat:torch.Tensor,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
), None, None ), None, None
else: else:
...@@ -88,7 +89,7 @@ def _apply_normalization(inputmat:torch.Tensor, ...@@ -88,7 +89,7 @@ def _apply_normalization(inputmat:torch.Tensor,
) )
else: else:
return normalization_func( return normalization_func(
*inputs, eps, zero_centered_gamma *inputs, eps, fwd_ln_sm_margin, zero_centered_gamma
), None, None ), None, None
if normalization == "RMSNorm": if normalization == "RMSNorm":
output = (ln_out, None, output[1]) output = (ln_out, None, output[1])
......
...@@ -34,6 +34,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -34,6 +34,7 @@ class _LayerNorm(torch.autograd.Function):
eps: float, eps: float,
fwd_ln_sm_margin: int, fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int, bwd_ln_sm_margin: int,
inf_ln_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
is_grad_enabled: bool, is_grad_enabled: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
...@@ -58,7 +59,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -58,7 +59,7 @@ class _LayerNorm(torch.autograd.Function):
ctx.zero_centered_gamma = zero_centered_gamma ctx.zero_centered_gamma = zero_centered_gamma
else: else:
ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight, ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight,
ln_bias, eps, zero_centered_gamma), None, None ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma), None, None
return ln_out.view_as(inp) return ln_out.view_as(inp)
@staticmethod @staticmethod
...@@ -72,7 +73,7 @@ class _LayerNorm(torch.autograd.Function): ...@@ -72,7 +73,7 @@ class _LayerNorm(torch.autograd.Function):
d_ln_out, inputmat, mu, rsigma, ln_weight, d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None
class LayerNorm(torch.nn.Module): class LayerNorm(torch.nn.Module):
...@@ -148,6 +149,7 @@ class LayerNorm(torch.nn.Module): ...@@ -148,6 +149,7 @@ class LayerNorm(torch.nn.Module):
# communication overlap with LN. # communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
def reset_layer_norm_parameters(self) -> None: def reset_layer_norm_parameters(self) -> None:
"""Init LN params""" """Init LN params"""
...@@ -198,6 +200,7 @@ class LayerNorm(torch.nn.Module): ...@@ -198,6 +200,7 @@ class LayerNorm(torch.nn.Module):
self.eps, self.eps,
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.inf_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.activation_dtype, self.activation_dtype,
......
...@@ -999,6 +999,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -999,6 +999,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
# communication overlap with LN. # communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
...@@ -1165,7 +1166,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1165,7 +1166,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.return_layernorm_output, self.return_layernorm_output,
self.return_layernorm_output_gathered, self.return_layernorm_output_gathered,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.normalization, self.normalization,
......
...@@ -1427,6 +1427,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1427,6 +1427,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
# communication overlap with LN. # communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
...@@ -1575,7 +1576,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1575,7 +1576,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.bias_gelu_nvfusion, self.bias_gelu_nvfusion,
self.set_parallel_mode, self.set_parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fwd_ln_sm_margin, self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin, self.bwd_ln_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
self.activation, self.activation,
......
...@@ -31,6 +31,7 @@ class _RMSNorm(torch.autograd.Function): ...@@ -31,6 +31,7 @@ class _RMSNorm(torch.autograd.Function):
eps: float, eps: float,
fwd_rmsnorm_sm_margin: int, fwd_rmsnorm_sm_margin: int,
bwd_rmsnorm_sm_margin: int, bwd_rmsnorm_sm_margin: int,
inf_rmsnorm_sm_margin: int,
zero_centered_gamma: bool, zero_centered_gamma: bool,
is_grad_enabled: bool, is_grad_enabled: bool,
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
...@@ -55,7 +56,7 @@ class _RMSNorm(torch.autograd.Function): ...@@ -55,7 +56,7 @@ class _RMSNorm(torch.autograd.Function):
ctx.zero_centered_gamma = zero_centered_gamma ctx.zero_centered_gamma = zero_centered_gamma
else: else:
rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight, rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight,
eps, eps, inf_rmsnorm_sm_margin,
zero_centered_gamma) zero_centered_gamma)
return rmsnorm_out.view_as(inp) return rmsnorm_out.view_as(inp)
...@@ -79,6 +80,7 @@ class _RMSNorm(torch.autograd.Function): ...@@ -79,6 +80,7 @@ class _RMSNorm(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -151,6 +153,7 @@ class RMSNorm(torch.nn.Module): ...@@ -151,6 +153,7 @@ class RMSNorm(torch.nn.Module):
# communication overlap with RMSNorm. # communication overlap with RMSNorm.
self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))
def reset_rms_norm_parameters(self) -> None: def reset_rms_norm_parameters(self) -> None:
"""Init RMSNorm params""" """Init RMSNorm params"""
...@@ -195,6 +198,7 @@ class RMSNorm(torch.nn.Module): ...@@ -195,6 +198,7 @@ class RMSNorm(torch.nn.Module):
self.eps, self.eps,
self.fwd_rmsnorm_sm_margin, self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin, self.bwd_rmsnorm_sm_margin,
self.inf_rmsnorm_sm_margin,
self.zero_centered_gamma, self.zero_centered_gamma,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.activation_dtype, self.activation_dtype,
......
...@@ -304,9 +304,9 @@ def _ones_like(g, inp, dtype): ...@@ -304,9 +304,9 @@ def _ones_like(g, inp, dtype):
return one return one
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b") @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma): scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma):
"""ONNX graph for layernorm_fwd_fp8""" """ONNX graph for layernorm_fwd_fp8"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
inp_dtype = get_TensorProtoDataType(inputs) inp_dtype = get_TensorProtoDataType(inputs)
...@@ -316,13 +316,13 @@ def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, ...@@ -316,13 +316,13 @@ def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
if inp_dtype != get_TensorProtoDataType(bias): if inp_dtype != get_TensorProtoDataType(bias):
bias = g.op("Cast", bias, to_i=inp_dtype) bias = g.op("Cast", bias, to_i=inp_dtype)
ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma) ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f", "b") @symbolic_helper.parse_args("v", "v", "v", "f", "i", "b")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma):
"""ONNX graph for layernorm_fwd""" """ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
...@@ -352,9 +352,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): ...@@ -352,9 +352,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
) )
return ln return ln
@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "b") @symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax, def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma): scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd_fp8""" """ONNX graph for rmsnorm_fwd_fp8"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
inp_dtype = get_TensorProtoDataType(inputs) inp_dtype = get_TensorProtoDataType(inputs)
...@@ -362,13 +362,13 @@ def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax, ...@@ -362,13 +362,13 @@ def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax,
if inp_dtype != get_TensorProtoDataType(weight): if inp_dtype != get_TensorProtoDataType(weight):
weight = g.op("Cast", weight, to_i=inp_dtype) weight = g.op("Cast", weight, to_i=inp_dtype)
ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma) ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor) fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln return fp8_ln
@symbolic_helper.parse_args("v", "v", "f", "b") @symbolic_helper.parse_args("v", "v", "f", "i", "b")
def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma): def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd""" """ONNX graph for rmsnorm_fwd"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment