"tests/pytorch/debug/conftest.py" did not exist on "1d903f5e6d5d36eef5f44bbffdc7719b703637e1"
Unverified commit 5d34b2ac, authored by Sangkug Lym and committed by GitHub

Add SM margin to LayerNorm in inference (#772)



* Add LN margin to inference
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix symbolic func registration
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix grads
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b4ef463d
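Context for the change: the training paths already let users keep a few streaming multiprocessors (SMs) free during LayerNorm/RMSNorm, via NVTE_FWD_LAYERNORM_SM_MARGIN and NVTE_BWD_LAYERNORM_SM_MARGIN, so a communication kernel can overlap with the norm (see the "# communication overlap with LN." comments in the hunks below). The inference entry points, however, hard-coded a margin of 0. This commit threads a new NVTE_INF_LAYERNORM_SM_MARGIN through every no-grad path. A minimal configuration sketch, with illustrative values:

```python
# The three margins are read from the environment when a module is
# constructed (values below are illustrative, not recommendations).
import os

os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "8"   # training forward pass
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "8"   # training backward pass
os.environ["NVTE_INF_LAYERNORM_SM_MARGIN"] = "16"  # new: no-grad/inference pass
```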
@@ -660,6 +660,7 @@ def test_export_layernorm(
            self.meta,
            self.fp8_tensor,
            self.fp8_type,
+            0,
            zero_centered_gamma)
        ret = cast_from_fp8(
@@ -748,6 +749,7 @@ def test_export_rmsnorm(
            self.meta,
            self.fp8_tensor,
            self.fp8_type,
+            0,
            zero_centered_gamma)
        ret = cast_from_fp8(
@@ -1279,6 +1281,7 @@ def test_export_gemm_layernorm(
            self.meta,
            self.fp8_tensor,
            self.fp8_type,
+            0,
            zero_centered_gamma)
        x = cast_from_fp8(

@@ -565,6 +565,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

    def _te_forward(
        self,
@@ -600,7 +601,7 @@ class LayerNormLinear(TransformerEngineBaseLayer):
            self.activation_dtype,
            self.return_layernorm_output,
            paddle.is_grad_enabled(),
-            self.fwd_ln_sm_margin,
+            self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
            self.normalization,
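Both Paddle modules here (and the PyTorch ones further down) select the margin with the same one-liner: the training margin while autograd is recording, the new inference margin otherwise. Distilled into a standalone sketch (the function name is illustrative, not part of the commit):

```python
def pick_forward_sm_margin(fwd_margin: int, inf_margin: int,
                           grad_enabled: bool) -> int:
    """Mirror of the selection added above: training forward passes keep the
    existing forward margin; no-grad (inference) passes use the new one."""
    return fwd_margin if grad_enabled else inf_margin
```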
@@ -824,6 +824,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

    def _te_forward(
        self,
@@ -865,7 +866,7 @@ class LayerNormMLP(TransformerEngineBaseLayer):
            self.activation_dtype,
            self.return_layernorm_output,
            paddle.is_grad_enabled(),
-            self.fwd_ln_sm_margin,
+            self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
            self.normalization,

@@ -66,6 +66,7 @@ def layernorm_fwd_fp8_inf(
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
+    sm_margin: int,
    zero_centered_gamma,
) -> torch.Tensor:
    """LayerNorm with FP8 output.
@@ -83,6 +84,7 @@ def layernorm_fwd_fp8_inf(
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
+        sm_margin,
        zero_centered_gamma)
    return ret
@@ -92,6 +94,7 @@ def layernorm_fwd_inf(
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
+    sm_margin: int,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """LayerNorm with non-FP8 output"""
@@ -100,6 +103,7 @@ def layernorm_fwd_inf(
        weight,
        bias,
        eps,
+        sm_margin,
        zero_centered_gamma,
    )
@@ -149,6 +153,7 @@ def rmsnorm_fwd_fp8_inf(
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
+    sm_margin: int,
    zero_centered_gamma,
) -> torch.Tensor:
    """RMSNorm with FP8 output.
@@ -165,6 +170,7 @@ def rmsnorm_fwd_fp8_inf(
        fp8_meta_tensor.scale_inv,
        fp8_tensor,
        otype,
+        sm_margin,
        zero_centered_gamma)
    return ret
@@ -173,6 +179,7 @@ def rmsnorm_fwd_inf(
    inp: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
+    sm_margin: int,
    zero_centered_gamma: bool,
) -> torch.Tensor:
    """RMSNorm with non-FP8 output"""
@@ -180,5 +187,6 @@ def rmsnorm_fwd_inf(
        inp,
        weight,
        eps,
+        sm_margin,
        zero_centered_gamma,
    )
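After this change, all four `*_fwd_inf` wrappers take the margin positionally, between `eps`/`otype` and `zero_centered_gamma`. A hedged sketch of a direct call; the import path and tensor shapes are assumptions for illustration, not part of the commit:

```python
import torch
import transformer_engine.pytorch.cpp_extensions as tex  # assumed import path

inp = torch.randn(8, 1024, device="cuda")
weight = torch.ones(1024, device="cuda")
bias = torch.zeros(1024, device="cuda")

# New positional order: ..., eps, sm_margin, zero_centered_gamma.
out = tex.layernorm_fwd_inf(inp, weight, bias, 1e-5, 4, False)
```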
@@ -408,6 +408,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                 at::Tensor amax,
                                 at::Tensor scale_inv,
                                 transformer_engine::DType otype,
+                                 const int sm_margin,
                                 const bool zero_centered_gamma
);
@@ -432,6 +433,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                             const at::Tensor &weight,
                             const at::Tensor &bias,
                             float eps,
+                             const int sm_margin,
                             const bool zero_centered_gamma
);
@@ -478,6 +480,7 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               transformer_engine::DType otype,
+                               const int sm_margin,
                               const bool zero_centered_gamma
);
@@ -499,6 +502,7 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
                           const at::Tensor &weight,
                           float eps,
+                           const int sm_margin,
                           const bool zero_centered_gamma
);

@@ -154,12 +154,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                 at::Tensor amax,
                                 at::Tensor scale_inv,
                                 transformer_engine::DType otype,
+                                 const int sm_margin,
                                 const bool zero_centered_gamma
) {
  // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
  // which only returns the normalized output.
  std::vector<at::Tensor> out = layernorm_fwd_fp8(
-      input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
+      input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
  return out[0];
}
@@ -203,11 +204,13 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                             const at::Tensor &weight,
                             const at::Tensor &bias,
                             float eps,
+                             const int sm_margin,
                             const bool zero_centered_gamma
) {
  // This is a specialized version of layernorm_fwd, optimized for inference,
  // which only returns the normalized output.
-  std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma);
+  std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, sm_margin,
+                                              zero_centered_gamma);
  return out[0];
}
@@ -345,12 +348,13 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               transformer_engine::DType otype,
+                               const int sm_margin,
                               const bool zero_centered_gamma
) {
  // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
  // which only returns the normalized output.
  std::vector<at::Tensor> out = rmsnorm_fwd_fp8(
-      input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
+      input, weight, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
  return out[0];
}
@@ -391,10 +395,11 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
                           const at::Tensor &weight,
                           float eps,
+                           const int sm_margin,
                           const bool zero_centered_gamma
) {
  // This is a specialized version of rmsnorm_fwd, optimized for inference,
  // which only returns the normalized output.
-  std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma);
+  std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma);
  return out[0];
}
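The C++ entry points above previously pinned the inference margin to 0; they now forward `sm_margin` to the underlying launchers, which roughly cap the norm kernel's grid at the device SM count minus the margin. Illustrative arithmetic only (the value is an assumption, not a recommendation):

```python
# How an SM margin bounds kernel occupancy, in back-of-the-envelope form.
import torch

props = torch.cuda.get_device_properties(0)
sm_margin = 4  # illustrative value
usable = props.multi_processor_count - sm_margin
print(f"norm kernel may occupy {usable}/{props.multi_processor_count} SMs")
```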
@@ -365,6 +365,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                    at::Tensor scale_inv,
                                    int64_t fp8_tensor,
                                    int64_t otype,
+                                    const int8_t sm_margin,
                                    const bool zero_centered_gamma) {
  transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  float eps_float = static_cast<float>(eps);
@@ -377,6 +378,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                           amax,
                                           scale_inv,
                                           otype_arg,
+                                           sm_margin,
                                           zero_centered_gamma);
  return output;
@@ -387,6 +389,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
                                const at::Tensor &weight,
                                const at::Tensor &bias,
                                double eps,
+                                const int8_t sm_margin,
                                const bool zero_centered_gamma) {
  float eps_float = static_cast<float>(eps);
@@ -394,6 +397,7 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
                                       weight,
                                       bias,
                                       eps_float,
+                                       sm_margin,
                                       zero_centered_gamma);
  return output;
@@ -408,6 +412,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                  at::Tensor scale_inv,
                                  int64_t fp8_tensor,
                                  int64_t otype,
+                                  const int8_t sm_margin,
                                  const bool zero_centered_gamma) {
  transformer_engine::DType otype_arg = reverse_map_dtype(otype);
  float eps_float = static_cast<float>(eps);
@@ -419,6 +424,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                         amax,
                                         scale_inv,
                                         otype_arg,
+                                         sm_margin,
                                         zero_centered_gamma);
  return output;
@@ -428,12 +434,14 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
                              const at::Tensor &weight,
                              double eps,
+                              const int8_t sm_margin,
                              const bool zero_centered_gamma) {
  float eps_float = static_cast<float>(eps);

  at::Tensor output = rmsnorm_fwd_inf(input,
                                      weight,
                                      eps_float,
+                                      sm_margin,
                                      zero_centered_gamma);
  return output;

@@ -78,6 +78,7 @@ def _apply_normalization(inputmat:torch.Tensor,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
+                fwd_ln_sm_margin,
                zero_centered_gamma,
            ), None, None
        else:
@@ -88,7 +89,7 @@ def _apply_normalization(inputmat:torch.Tensor,
            )
        else:
            return normalization_func(
-                *inputs, eps, zero_centered_gamma
+                *inputs, eps, fwd_ln_sm_margin, zero_centered_gamma
            ), None, None
    if normalization == "RMSNorm":
        output = (ln_out, None, output[1])

@@ -34,6 +34,7 @@ class _LayerNorm(torch.autograd.Function):
        eps: float,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
+        inf_ln_sm_margin: int,
        zero_centered_gamma: bool,
        is_grad_enabled: bool,
        activation_dtype: torch.dtype,
@@ -58,7 +59,7 @@ class _LayerNorm(torch.autograd.Function):
            ctx.zero_centered_gamma = zero_centered_gamma
        else:
            ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight,
-                ln_bias, eps, zero_centered_gamma), None, None
+                ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma), None, None
        return ln_out.view_as(inp)

    @staticmethod
@@ -72,7 +73,7 @@ class _LayerNorm(torch.autograd.Function):
            d_ln_out, inputmat, mu, rsigma, ln_weight,
            ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
        )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None
+        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None


class LayerNorm(torch.nn.Module):
@@ -148,6 +149,7 @@ class LayerNorm(torch.nn.Module):
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
@@ -198,6 +200,7 @@ class LayerNorm(torch.nn.Module):
            self.eps,
            self.fwd_ln_sm_margin,
            self.bwd_ln_sm_margin,
+            self.inf_ln_sm_margin,
            self.zero_centered_gamma,
            torch.is_grad_enabled(),
            self.activation_dtype,
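With the wiring above, which margin a `LayerNorm` module uses is decided purely by `torch.is_grad_enabled()` at call time. A hedged end-to-end sketch (hidden size and shapes are illustrative):

```python
import torch
import transformer_engine.pytorch as te

ln = te.LayerNorm(4096).cuda()  # reads the NVTE_*_LAYERNORM_SM_MARGIN vars at init
x = torch.randn(8, 4096, device="cuda", requires_grad=True)

y_train = ln(x)                 # grad enabled  -> fwd_ln_sm_margin
with torch.no_grad():
    y_infer = ln(x)             # grad disabled -> new inf_ln_sm_margin
```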
@@ -999,6 +999,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

        # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
        self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
@@ -1165,7 +1166,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
            self.return_layernorm_output,
            self.return_layernorm_output_gathered,
            torch.is_grad_enabled(),
-            self.fwd_ln_sm_margin,
+            self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
            self.normalization,

@@ -1427,6 +1427,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

        # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
        self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
@@ -1575,7 +1576,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
            self.bias_gelu_nvfusion,
            self.set_parallel_mode,
            torch.is_grad_enabled(),
-            self.fwd_ln_sm_margin,
+            self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
            self.activation,

@@ -31,6 +31,7 @@ class _RMSNorm(torch.autograd.Function):
        eps: float,
        fwd_rmsnorm_sm_margin: int,
        bwd_rmsnorm_sm_margin: int,
+        inf_rmsnorm_sm_margin: int,
        zero_centered_gamma: bool,
        is_grad_enabled: bool,
        activation_dtype: torch.dtype,
@@ -55,7 +56,7 @@ class _RMSNorm(torch.autograd.Function):
            ctx.zero_centered_gamma = zero_centered_gamma
        else:
            rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight,
-                                              eps,
+                                              eps, inf_rmsnorm_sm_margin,
                                               zero_centered_gamma)
        return rmsnorm_out.view_as(inp)
@@ -79,6 +80,7 @@ class _RMSNorm(torch.autograd.Function):
            None,
            None,
            None,
+            None,
        )
@@ -151,6 +153,7 @@ class RMSNorm(torch.nn.Module):
        # communication overlap with RMSNorm.
        self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+        self.inf_rmsnorm_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

    def reset_rms_norm_parameters(self) -> None:
        """Init RMSNorm params"""
@@ -195,6 +198,7 @@ class RMSNorm(torch.nn.Module):
            self.eps,
            self.fwd_rmsnorm_sm_margin,
            self.bwd_rmsnorm_sm_margin,
+            self.inf_rmsnorm_sm_margin,
            self.zero_centered_gamma,
            torch.is_grad_enabled(),
            self.activation_dtype,

@@ -304,9 +304,9 @@ def _ones_like(g, inp, dtype):
    return one


-@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b")
+@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "i", "b")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
-                           scale_inv, fp8_tensor, otype, zero_centered_gamma):
+                           scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma):
    """ONNX graph for layernorm_fwd_fp8"""
    # pylint: disable=unused-argument
    inp_dtype = get_TensorProtoDataType(inputs)
@@ -316,13 +316,13 @@ def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
    if inp_dtype != get_TensorProtoDataType(bias):
        bias = g.op("Cast", bias, to_i=inp_dtype)

-    ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma)
+    ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma)
    fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
    return fp8_ln
@symbolic_helper.parse_args("v", "v", "v", "f", "b")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
@symbolic_helper.parse_args("v", "v", "v", "f", "i", "b")
def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
@@ -352,9 +352,9 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
    )
    return ln


-@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "b")
+@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax,
-                         scale_inv, fp8_tensor, otype, zero_centered_gamma):
+                         scale_inv, fp8_tensor, otype, sm_margin, zero_centered_gamma):
    """ONNX graph for rmsnorm_fwd_fp8"""
    # pylint: disable=unused-argument
    inp_dtype = get_TensorProtoDataType(inputs)
@@ -362,13 +362,13 @@ def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax,
    if inp_dtype != get_TensorProtoDataType(weight):
        weight = g.op("Cast", weight, to_i=inp_dtype)

-    ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma)
+    ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma)
    fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
    return fp8_ln
@symbolic_helper.parse_args("v", "v", "f", "b")
def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
@symbolic_helper.parse_args("v", "v", "f", "i", "b")
def onnx_rmsnorm_fwd(g, inputs, weight, eps, sm_margin, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd"""
# pylint: disable=unused-argument
......
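The "Fix symbolic func registration" part of the commit lives in this file: a `parse_args` descriptor string must match the op's argument list one-to-one, so each symbolic gains an extra `"i"` (int) slot for `sm_margin`. The margin only affects kernel scheduling, never the computed result, so the symbolics accept the argument and drop it; it never appears in the exported ONNX graph. For reference, the descriptor codes used above (from `torch.onnx.symbolic_helper`):

```python
# parse_args descriptor codes seen in this file:
#   "v"  -> graph value (tensor)
#   "f"  -> Python float        "fs" -> list of floats
#   "i"  -> Python int          "b"  -> bool
# The added "i" keeps each descriptor aligned with the widened kernel signature.
```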