"superbench/vscode:/vscode.git/clone" did not exist on "0972b223a1a0e0684ead3c1ecc202273e9442494"
Unverified commit 275902fd, authored by Kirthi Shankar Sivamani and committed by GitHub

Add margin for LayerNorm kernel SM usage (#57)



* Add margin for LayerNorm kernel SM usage

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert stylistic changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 7f270330
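
The margins introduced by this patch are read from environment variables at module construction time, so they must be exported before the modules are built. A minimal usage sketch (not part of this patch; the margin values are illustrative, and it assumes transformer_engine.pytorch with this change installed plus a CUDA device):

```python
import os

# Reserve SMs for overlapping work (e.g. a communication kernel). These must
# be set before the Transformer Engine module is constructed, since the
# values are read once in __init__.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "4"
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "8"

import torch
import transformer_engine.pytorch as te

ln = te.LayerNorm(4096).cuda()
out = ln(torch.randn(16, 4096, device="cuda"))
```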
@@ -239,6 +239,7 @@ def layernorm_fwd_fp8(
     fp8_meta_tensor: tex.FP8TensorMeta,
     fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
     otype: tex.DType,
+    sm_margin: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """LayerNorm with FP8 output"""
     return tex.layernorm_fwd_fp8(
@@ -250,6 +251,7 @@ def layernorm_fwd_fp8(
         fp8_meta_tensor.amax_history[0][fp8_tensor],
         fp8_meta_tensor.scale_inv[fp8_tensor],
         otype,
+        sm_margin,
     )
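With the wrapper change above, `sm_margin` becomes a required trailing argument of the Python-side helpers. A hedged sketch of calling the patched non-FP8 extension entry point directly (the compiled extension module name is an assumption and may differ by version; assumes a CUDA device):

```python
import torch
import transformer_engine_extensions as tex  # extension name is an assumption

hidden = 4096
x = torch.randn(8, hidden, device="cuda")
gamma = torch.ones(hidden, device="cuda")
beta = torch.zeros(hidden, device="cuda")

# The trailing 0 is sm_margin; 0 reproduces the pre-patch behaviour of
# letting the kernel use every SM on the device.
ln_out, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, 1e-5, 0)
```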
@@ -373,7 +373,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                       const at::Tensor &x,
                                       const at::Tensor &mu,
                                       const at::Tensor &rsigma,
-                                      const at::Tensor &gamma
+                                      const at::Tensor &gamma,
+                                      const int sm_margin
 ) {
     auto dx = at::empty_like(x);
     auto dgamma = at::empty_like(gamma);
@@ -393,7 +394,7 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
     nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
                        dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
                        dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     // Alloc space for Tensors.
@@ -418,7 +419,7 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
     nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(),
                        dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(),
                        dbeta_part.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     return { dx, dgamma, dbeta };
@@ -432,7 +433,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           at::Tensor scale,
                                           at::Tensor amax,
                                           at::Tensor scale_inv,
-                                          transformer_engine::DType otype
+                                          transformer_engine::DType otype,
+                                          const int sm_margin
 ) {
     using namespace transformer_engine;
@@ -457,7 +459,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
     // This call populates workspace and barrier tensors with the required config
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     // Fill workspace and barrier
@@ -476,7 +478,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
     // Actual call to fwd kernel
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     return {ln_out, mu, rsigma};
@@ -495,7 +497,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
     // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
     // which only returns the normalized output.
     std::vector<at::Tensor> out = layernorm_fwd_fp8(
-        input, weight, bias, eps, scale, amax, scale_inv, otype);
+        input, weight, bias, eps, scale, amax, scale_inv, otype, 0);
     return out[0];
 }
@@ -503,7 +505,8 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
-                                      float eps
+                                      float eps,
+                                      const int sm_margin
 ) {
     using namespace transformer_engine;
@@ -526,7 +529,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
     // This call populates workspace and barrier tensors with the required config
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     // Fill workspace and barrier
@@ -545,7 +548,7 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
     // Actual call to fwd kernel
     nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
                        mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
-                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount,
+                       at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                        workspace.data(), barrier.data());
     return {ln_out, mu, rsigma};
@@ -559,7 +562,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
 ) {
     // This is a specialized version of layernorm_fwd, optimized for inference,
     // which only returns the normalized output.
-    std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps);
+    std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0);
     return out[0];
 }
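On the C++ side the margin is applied as `multiProcessorCount - sm_margin`, so it is simply the number of SMs withheld from the LayerNorm kernels, and the inference-only entry points pass a hardcoded 0 to keep their old behaviour. The same arithmetic, mirrored in Python purely for illustration (the margin value is illustrative):

```python
import torch

props = torch.cuda.get_device_properties(torch.cuda.current_device())
sm_margin = 4  # illustrative value
effective_sms = props.multi_processor_count - sm_margin
print(f"LayerNorm kernels restricted to {effective_sms}/{props.multi_processor_count} SMs")
```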
@@ -81,7 +81,8 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                       const at::Tensor &x,
                                       const at::Tensor &mu,
                                       const at::Tensor &rsigma,
-                                      const at::Tensor &gamma
+                                      const at::Tensor &gamma,
+                                      const int sm_margin
 );
@@ -92,7 +93,8 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                           at::Tensor scale,
                                           at::Tensor amax,
                                           at::Tensor scale_inv,
-                                          transformer_engine::DType otype
+                                          transformer_engine::DType otype,
+                                          const int sm_margin
 );
 at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
@@ -108,7 +110,8 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
 std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                       const at::Tensor &weight,
                                       const at::Tensor &bias,
-                                      float eps
+                                      float eps,
+                                      const int sm_margin
 );
 at::Tensor layernorm_fwd_inf(const at::Tensor &input,
@@ -597,7 +597,9 @@ class _LayerNormLinear(torch.autograd.Function):
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         return_layernorm_output: bool,
-        is_training: bool
+        is_training: bool,
+        fwd_ln_sm_margin: int,
+        bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -627,6 +629,7 @@ class _LayerNormLinear(torch.autograd.Function):
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
+                    fwd_ln_sm_margin,
                 )
             else:
                 mu = rsigma = None
@@ -642,7 +645,7 @@ class _LayerNormLinear(torch.autograd.Function):
         else:
             if is_training:
                 ln_out_return, mu, rsigma = tex.layernorm_fwd(
-                    inputmat, ln_weight, ln_bias, eps
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
             else:
                 ln_out_return, mu, rsigma = layernorm_fwd_inf(
@@ -657,7 +660,9 @@ class _LayerNormLinear(torch.autograd.Function):
             )
         else:
             if is_training:
-                ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
+                ln_out, mu, rsigma = tex.layernorm_fwd(
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
+                )
             else:
                 ln_out, mu, rsigma = layernorm_fwd_inf(
                     inputmat, ln_weight, ln_bias, eps
@@ -749,6 +754,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
         ctx.return_layernorm_output = return_layernorm_output
+        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
@@ -915,7 +921,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
             dxmat, dgamma, dbeta = tex.layernorm_bwd(
-                d_ln_out, inputmat, mu, rsigma, ln_weight
+                d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
             )
         if not ctx.use_bias:
@@ -942,6 +948,8 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
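The two extra `None`s in the backward return exist because `torch.autograd.Function.backward` must return exactly one gradient per `forward` argument; the new integer margins are non-differentiable, so their slots are `None`. A self-contained sketch of that rule in plain PyTorch (the `Scale` function is hypothetical, for illustration only):

```python
import torch

class Scale(torch.autograd.Function):
    """Multiply by an integer factor; the factor itself gets no gradient."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: int) -> torch.Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # One return value per forward argument: a grad for x, None for factor.
        return grad_out * ctx.factor, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```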
@@ -1118,6 +1126,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
         else:
             self.gemm_bias_unfused_add = False
+        # These many SMs are subtracted from the total SM count when calling forward
+        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
+        # kernels from using all SMs in the device. This is useful for cases such as
+        # communication overlap with LN.
+        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         init.ones_(self.layer_norm_weight)
@@ -1189,6 +1204,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.parallel_mode,
             self.return_layernorm_output,
             self.training,
+            self.fwd_ln_sm_margin,
+            self.bwd_ln_sm_margin,
         )
         out = fwd_fn(*args)
@@ -1779,7 +1796,9 @@ class _LayerNormMLP(torch.autograd.Function):
         return_layernorm_output: bool,
         bias_gelu_nvfusion: bool,
         set_parallel_mode: bool,
-        is_training: bool
+        is_training: bool,
+        fwd_ln_sm_margin: int,
+        bwd_ln_sm_margin: int,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -1808,6 +1827,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     fp8_meta["scaling_fwd"],
                     tex.FP8FwdTensors.GEMM1_INPUT,
                     fp8_dtype_forward,
+                    fwd_ln_sm_margin,
                 )
             else:
                 ln_out = layernorm_fwd_fp8_inf(
@@ -1821,7 +1841,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 )
             else:
                 ln_out_return, mu, rsigma = tex.layernorm_fwd(
-                    inputmat, ln_weight, ln_bias, eps
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                 )
                 ln_out = cast_to_fp8(
                     ln_out_return,
@@ -1831,7 +1851,9 @@ class _LayerNormMLP(torch.autograd.Function):
             )
         else:
             if is_training:
-                ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
+                ln_out, mu, rsigma = tex.layernorm_fwd(
+                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
+                )
             else:
                 ln_out, mu, rsigma = layernorm_fwd_inf(
                     inputmat, ln_weight, ln_bias, eps
@@ -1988,6 +2010,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
         ctx.return_layernorm_output = return_layernorm_output
         ctx.set_parallel_mode = set_parallel_mode
+        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         # Row Parallel Linear
         if set_parallel_mode and sequence_parallel:
@@ -2281,7 +2304,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
             dxmat, dgamma, dbeta = tex.layernorm_bwd(
-                d_ln_out, inputmat, mu, rsigma, ln_weight
+                d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
             )
         if not ctx.use_bias:
@@ -2313,6 +2336,8 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
+            None,
         )
@@ -2527,6 +2552,13 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.size_per_partition, seq_length, micro_batch_size
         )
+        # These many SMs are subtracted from the total SM count when calling forward
+        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
+        # kernels from using all SMs in the device. This is useful for cases such as
+        # communication overlap with LN.
+        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
         init.ones_(self.layer_norm_weight)
@@ -2590,6 +2622,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.bias_gelu_nvfusion,
             self.set_parallel_mode,
             self.training,
+            self.fwd_ln_sm_margin,
+            self.bwd_ln_sm_margin,
         )
         out = fwd_fn(*args)
@@ -2618,6 +2652,8 @@ class _LayerNorm(torch.autograd.Function):
         ln_weight: torch.Tensor,
         ln_bias: torch.Tensor,
         eps: float,
+        fwd_ln_sm_margin: int,
+        bwd_ln_sm_margin: int,
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
         in_features = ln_weight.numel()
@@ -2625,9 +2661,10 @@ class _LayerNorm(torch.autograd.Function):
         assert inp.shape[-1] == in_features, "LayerNorm not possible"
         inputmat = inp.view((-1, in_features))
-        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps)
+        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin)
         ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
         ctx.inp_shape = inp.shape
+        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
         return ln_out.view_as(inp)

     @staticmethod
@@ -2638,9 +2675,9 @@ class _LayerNorm(torch.autograd.Function):
         grad_output = grad_output.contiguous()
         d_ln_out = grad_output.view(inputmat.shape)
         dxmat, dgamma, dbeta = tex.layernorm_bwd(
-            d_ln_out, inputmat, mu, rsigma, ln_weight
+            d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
         )
-        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None
+        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None

 class LayerNorm(torch.nn.Module):
@@ -2695,6 +2732,13 @@ class LayerNorm(torch.nn.Module):
             setattr(self.bias, "sequence_parallel", sequence_parallel)
         self.reset_layer_norm_parameters()
+        # These many SMs are subtracted from the total SM count when calling forward
+        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
+        # kernels from using all SMs in the device. This is useful for cases such as
+        # communication overlap with LN.
+        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
+        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
+
     def load_state_dict(
         self,
         state_dict: Mapping[str, Any],
@@ -2725,4 +2769,11 @@ class LayerNorm(torch.nn.Module):
         if hasattr(self, "layer_norm_bias"):
             setattr(self, "bias", self.layer_norm_bias)
-        return _LayerNorm.apply(inp, self.weight, self.bias, self.eps)
+        return _LayerNorm.apply(
+            inp,
+            self.weight,
+            self.bias,
+            self.eps,
+            self.fwd_ln_sm_margin,
+            self.bwd_ln_sm_margin,
+        )
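
Note that nothing in the patch bounds the margin: a value at or above the device's SM count would hand the kernels a non-positive SM count. A cautious launcher can clamp the environment variable before the modules are constructed (a hedged sketch, not part of this change):

```python
import os
import torch

requested = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
total_sms = torch.cuda.get_device_properties(0).multi_processor_count
# Keep at least one SM for the LayerNorm kernels themselves.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = str(max(0, min(requested, total_sms - 1)))
```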