Unverified commit c67bb2fc authored by Przemyslaw Tredak, committed by GitHub

Adding other activation types to LayerNormMLP (#265)



* Added ReLU and GLU variants to common
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch changes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* PyTorch C++ lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix storage errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Compute bgrad
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix numerical tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent df6f347f
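
A minimal usage sketch of the new option (illustrative, not part of the diff; assumes the public `transformer_engine.pytorch` import path and a CUDA device):

```python
# Illustrative only: exercises the new `activation` argument added by this PR.
# Sizes are placeholders.
import torch
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="swiglu",  # new: also 'relu', 'geglu', 'reglu' besides 'gelu'
).cuda()

x = torch.randn(128, 1024, device="cuda")
y = mlp(x)  # hidden size preserved: [128, 1024]
```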
......@@ -167,14 +167,73 @@ at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
);
/***************************************************************************************************
* Activations
**************************************************************************************************/
at::Tensor fp8_gelu(at::Tensor input,
at::Tensor gelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor relu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor geglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor reglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor swiglu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor drelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dgeglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dreglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
at::Tensor dswiglu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
const at::Tensor &x,
......
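
For reference, a rough Python sketch of what the new GLU-variant kernels compute. This is a reading of the declarations above, not the kernel source: the real CUDA kernels additionally fuse the FP8 cast driven by `scale`/`amax`/`scale_inv`, and which half serves as the gate is an assumption of this sketch.

```python
import torch
import torch.nn.functional as F

def geglu_ref(x: torch.Tensor) -> torch.Tensor:
    # GLU variants consume 2*H features on the last dim and emit H:
    # one half gates the other.
    a, b = x.chunk(2, dim=-1)
    return F.gelu(a) * b

def reglu_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.relu(a) * b

def swiglu_ref(x: torch.Tensor) -> torch.Tensor:
    a, b = x.chunk(2, dim=-1)
    return F.silu(a) * b  # SiLU, a.k.a. swish
```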
......@@ -47,17 +47,177 @@ at::Tensor cast_from_fp8_ts(const at::Tensor &input,
return output;
}
at::Tensor fp8_gelu_ts(at::Tensor input,
at::Tensor gelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor output = fp8_gelu(input,
scale[fp8_tensor],
amax[0][fp8_tensor],
scale_inv[fp8_tensor],
at::Tensor s, a, s_inv;
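// The scaling tensors can be empty (e.g. non-FP8 execution during ONNX
// export); index into the per-tensor FP8 metadata only when populated.
// The same pattern repeats in the wrappers below.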
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = gelu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor relu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = relu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor reglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = reglu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor geglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = geglu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
at::Tensor swiglu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}
if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}
if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}
at::Tensor output = swiglu(input,
s,
a,
s_inv,
otype_arg);
return output;
}
......@@ -171,7 +331,11 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
m.def("fp8_gelu_ts", &fp8_gelu_ts);
m.def("gelu_ts", &gelu_ts);
m.def("relu_ts", &relu_ts);
m.def("geglu_ts", &geglu_ts);
m.def("reglu_ts", &reglu_ts);
m.def("swiglu_ts", &swiglu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
......
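
These registrations expose the wrappers as `torch.ops.tex_ts.*`. A hedged sketch of a direct call; the extension module name and the integer dtype code decoded by `reverse_map_dtype` are assumptions, not taken from this diff:

```python
import torch
import transformer_engine_extensions  # noqa: F401  (assumed to register tex_ts)

x = torch.randn(16, 32, device="cuda", dtype=torch.float16)
empty = torch.Tensor()   # empty scale/amax/scale_inv selects the non-FP8 branch
OTYPE = 3                # placeholder dtype code (assumption)
y = torch.ops.tex_ts.relu_ts(x, empty, empty, empty, 0, OTYPE)
```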
......@@ -10,8 +10,6 @@ import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
from .base import (
get_workspace,
_prepare_backward,
......@@ -42,23 +40,28 @@ from ..distributed import (
reduce_scatter_along_first_dim,
gather_along_first_dim,
)
from ..cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
fp8_gelu,
fp8_cast_transpose_bgrad_dgelu_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
__all__ = ["LayerNormMLP"]
def _act_func(activation: str):
funcs = {
'gelu': (tex.gelu, tex.dgelu),
'relu': (tex.relu, tex.drelu),
'geglu': (tex.geglu, tex.dgeglu),
'reglu': (tex.reglu, tex.dreglu),
'swiglu': (tex.swiglu, tex.dswiglu),
}
if activation not in funcs:
raise "Activation type " + activation + " is not supported!"
return funcs[activation]
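# _act_func returns a (forward_kernel, backward_kernel) pair: forward()
# uses index [0], backward() uses index [1].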
class _LayerNormMLP(torch.autograd.Function):
"""LayerNormMLP semi-top level module
Calls custom CUDA extensions.
......@@ -102,6 +105,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_bulk_dgrad: bool,
ub_split_rs: bool,
ub_split_ag: bool,
activation: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
......@@ -114,6 +118,8 @@ class _LayerNormMLP(torch.autograd.Function):
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
activation_func = _act_func(activation)[0]
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
......@@ -140,7 +146,7 @@ class _LayerNormMLP(torch.autograd.Function):
if is_grad_enabled:
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
_, mu, rsigma = tex.layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
......@@ -153,7 +159,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out = ln_out,
)
else:
ln_out = layernorm_fwd_fp8_inf(
ln_out = tex.layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
......@@ -167,7 +173,7 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ln_out = cast_to_fp8(
ln_out = tex.cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
......@@ -185,7 +191,7 @@ class _LayerNormMLP(torch.autograd.Function):
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
ln_out, mu, rsigma = tex.layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
......@@ -210,7 +216,7 @@ class _LayerNormMLP(torch.autograd.Function):
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
tex.fp8_cast_transpose_fused(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -219,7 +225,7 @@ class _LayerNormMLP(torch.autograd.Function):
transpose_out=fc1_weight_t_fp8,
)
fp8_cast_transpose_fused(
tex.fp8_cast_transpose_fused(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
......@@ -229,21 +235,21 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
fc1_weight_t_fp8 = None
fc1_weight_fp8 = cast_to_fp8(
fc1_weight_fp8 = tex.cast_to_fp8(
fc1_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward,
)
fc2_weight_t_fp8 = None
fc2_weight_fp8 = cast_to_fp8(
fc2_weight_fp8 = tex.cast_to_fp8(
fc2_weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_WEIGHT,
fp8_dtype_forward,
)
fc1_out = fp8_gemm(
fc1_out = tex.fp8_gemm(
fc1_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -262,7 +268,7 @@ class _LayerNormMLP(torch.autograd.Function):
extra_output_tensor=ln_out if ub_split_ag else None,
)
gelu_out = fp8_gelu(
gelu_out = activation_func(
fc1_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
......@@ -281,7 +287,7 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
_ = fp8_gemm(
_ = tex.fp8_gemm(
fc2_weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM2_WEIGHT,
......@@ -319,14 +325,14 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(fc1_weight).float()
fc1_outputs = gemm(
fc1_outputs = tex.gemm(
fc1_weight,
ln_out_total,
activation_dtype,
get_workspace(),
bias=fc1_bias,
use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
gelu=not bias_gelu_nvfusion,
gelu=not bias_gelu_nvfusion and (activation == 'gelu'),
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
......@@ -334,10 +340,16 @@ class _LayerNormMLP(torch.autograd.Function):
if bias_gelu_nvfusion:
fc1_out, _, _ = fc1_outputs
gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
else:
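# For 'gelu' the activation was fused into the FC1 GEMM epilogue above
# (the gelu=... flag); other activations run as a standalone kernel on
# the GEMM output.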
if activation == 'gelu':
gelu_out, _, fc1_out = fc1_outputs
else:
fc1_out, _, _ = fc1_outputs
gelu_out = activation_func(fc1_out,
None,
tex.FP8FwdTensors.GEMM2_INPUT,
TE_DType[fc1_out.dtype])
if fp8_calibration:
# amax of fc2 input
......@@ -358,7 +370,7 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
_, _, _ = gemm(
_, _, _ = tex.gemm(
fc2_weight,
gelu_out,
activation_dtype,
......@@ -388,6 +400,7 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
)
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.fp8 = fp8
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
......@@ -448,6 +461,8 @@ class _LayerNormMLP(torch.autograd.Function):
fwd_scale_inverses,
) = ctx.saved_tensors
activation_func = _act_func(ctx.activation)[1]
if ctx.ub_bulk_dgrad:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
......@@ -504,7 +519,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
# FC2 DGRAD; Unconditional
fc2_dgrad = fp8_gemm(
fc2_dgrad = tex.fp8_gemm(
fc2_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_WEIGHT,
......@@ -525,7 +540,7 @@ class _LayerNormMLP(torch.autograd.Function):
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if fc2_weight.requires_grad:
gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
fc2_wgrad = fp8_gemm(
fc2_wgrad = tex.fp8_gemm(
gelu_out_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM2_INPUT,
......@@ -543,23 +558,33 @@ class _LayerNormMLP(torch.autograd.Function):
use_split_accumulator=_2X_ACC_WGRAD,
)
fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
if ctx.activation == 'gelu':
fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused(
fc2_dgrad,
fc1_out,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
else:
dgelu = activation_func(fc2_dgrad, fc1_out,
TE_DType[fc2_dgrad.dtype])
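# The fused kernel below casts the activation gradient to FP8, emits its
# transpose, and reduces the bias gradient in one pass; the name dgelu is
# rebound to the FP8 cast.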
fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_fused(
dgelu,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
fp8_dtype_backward,
)
else:
if fc2_weight.requires_grad:
gelu_out_c = cast_from_fp8(
gelu_out_c = tex.cast_from_fp8(
gelu_out,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM2_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc2_wgrad, _, _ = gemm(
fc2_wgrad, _, _ = tex.gemm(
gelu_out_c,
grad_output,
ctx.activation_dtype,
......@@ -573,11 +598,17 @@ class _LayerNormMLP(torch.autograd.Function):
else None,
)
if ctx.activation == 'gelu':
fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
fc2_dgrad, fc1_out, fc1_bias
)
else:
dgelu_no_fp8 = activation_func(fc2_dgrad,
fc1_out,
TE_DType[fc2_dgrad.dtype])
fc1_bias_grad = dgelu_no_fp8.sum(dim=0)
dgelu = cast_to_fp8(
dgelu = tex.cast_to_fp8(
dgelu_no_fp8,
ctx.fp8_meta["scaling_bwd"],
tex.FP8BwdTensors.GRAD_OUTPUT2,
......@@ -595,7 +626,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
)
# FC1 DGRAD: Unconditional
_ = fp8_gemm(
_ = tex.fp8_gemm(
fc1_weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
......@@ -613,13 +644,13 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
# FC2 DGRAD; Unconditional
fc2_dgrad, _, _ = gemm(
fc2_dgrad, _, _ = tex.gemm(
fc2_weight,
grad_output,
ctx.activation_dtype,
get_workspace(),
layout="NN",
gelu=not ctx.bias_gelu_nvfusion,
gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == 'gelu'),
grad=True,
gelu_input=fc1_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
......@@ -628,7 +659,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC2 WGRAD
if fc2_weight.requires_grad:
fc2_wgrad, fc2_bias_grad, _ = gemm(
fc2_wgrad, fc2_bias_grad, _ = tex.gemm(
gelu_out,
grad_output,
ctx.activation_dtype,
......@@ -640,10 +671,20 @@ class _LayerNormMLP(torch.autograd.Function):
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
if ctx.bias_gelu_nvfusion:
if ctx.bias_gelu_nvfusion and ctx.activation == 'gelu':
fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
else:
if ctx.activation == 'gelu':
dgelu = fc2_dgrad
else:
dgelu = activation_func(fc2_dgrad,
fc1_out,
TE_DType[fc2_dgrad.dtype])
# For non-FP8 execution, the FC1 bias gradient is fused into the FC1 wgrad
# GEMM and is skipped when wgrad is not required, so compute it explicitly here.
if not fc1_weight.requires_grad:
fc1_bias_grad = dgelu.sum(dim=0)
fc1_dgrad_size = list(dgelu.size())
fc1_dgrad_size[1] = fc1_weight.size(1)
......@@ -655,7 +696,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
)
# FC1 DGRAD: Unconditional
_, _, _ = gemm(
_, _, _ = tex.gemm(
fc1_weight,
dgelu,
ctx.activation_dtype,
......@@ -685,7 +726,7 @@ class _LayerNormMLP(torch.autograd.Function):
# FC1 WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
fc1_wgrad = fp8_gemm(
fc1_wgrad = tex.fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
......@@ -706,14 +747,14 @@ class _LayerNormMLP(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total_c = tex.cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
fc1_wgrad, _, _ = gemm(
fc1_wgrad, _, _ = tex.gemm(
ln_out_total_c,
dgelu_no_fp8,
ctx.activation_dtype,
......@@ -730,7 +771,7 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
# FC1 WGRAD
fc1_wgrad_outputs = gemm(
fc1_wgrad_outputs = tex.gemm(
ln_out_total,
dgelu,
ctx.activation_dtype,
......@@ -803,6 +844,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -821,6 +863,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
activation : str, default = 'gelu'
Activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'swiglu'.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
......@@ -893,6 +938,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
tp_size: int = 1,
init_method: Optional[Callable] = None,
bias: bool = True,
activation : str = "gelu",
output_layer_init_method: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
params_dtype: torch.dtype = torch.float32,
......@@ -910,10 +956,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.activation = activation
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
self.bias_gelu_nvfusion = (bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1"))) and
self.activation == 'gelu')
self.set_parallel_mode = set_parallel_mode
self.zero_centered_gamma = zero_centered_gamma
self.ub_bulk_wgrad = ub_bulk_wgrad
......@@ -963,10 +1011,14 @@ class LayerNormMLP(TransformerEngineBaseModule):
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
self.reset_layer_norm_parameters()
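# GLU variants gate one half of the FC1 output with the other half, so
# FC1 must produce twice the features to keep the MLP hidden width.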
if self.activation in ['reglu', 'geglu', 'swiglu']:
fc1_output_features = 2 * self.size_per_partition
else:
fc1_output_features = self.size_per_partition
# FC1 init
self.fc1_weight = Parameter(
torch.empty(
self.size_per_partition,
fc1_output_features,
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
......@@ -985,7 +1037,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
if self.use_bias:
self.fc1_bias = Parameter(
torch.empty(
self.size_per_partition,
fc1_output_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
......@@ -1162,6 +1214,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_bulk_dgrad,
self.ub_split_rs,
self.ub_split_ag,
self.activation,
)
out = fwd_fn(*args)
......
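
The FC1 shape change above is visible on the module parameters. A hedged check, with attribute names taken from this diff and a CUDA device assumed:

```python
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(1024, 4096, activation="geglu")
assert mlp.fc1_weight.shape == (2 * 4096, 1024)  # doubled for the GLU gate
assert mlp.fc1_bias.shape == (2 * 4096,)
assert mlp.fc2_weight.shape == (1024, 4096)      # FC2 unchanged
```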
......@@ -254,7 +254,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
register_custom_op_symbolic('tex_ts::fp8_gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::gelu_ts', onnx_fp8_gelu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
......@@ -136,6 +136,9 @@ class TransformerLayer(torch.nn.Module):
using :attr:`fuse_qkv_params=False`.
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
Parallelism parameters
----------------------
......@@ -214,6 +217,7 @@ class TransformerLayer(torch.nn.Module):
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
bias: bool = True,
activation: str = 'gelu'
) -> None:
super().__init__()
......@@ -316,9 +320,11 @@ class TransformerLayer(torch.nn.Module):
bias=bias,
)
# LayerNorm -> gelu(Linear + Bias) -> Linear
# LayerNorm -> activation(Linear + Bias) -> Linear
# parallel_mode not supported for LayerNormMLP,
# FC1 is CPL and FC2 is RPL
# In the case of GLU activation, FC1 handles both
# Linear layers before the activation
self.layernorm_mlp = LayerNormMLP(
hidden_size,
ffn_hidden_size,
......@@ -342,6 +348,7 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
activation=activation,
)
self.hidden_dropout = hidden_dropout
......
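
At the `TransformerLayer` level, a minimal sketch of the plumbed-through option (argument names from this diff; other constructor defaults assumed, sizes are placeholders):

```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation="reglu",  # forwarded to the LayerNormMLP block
)
```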