Unverified commit c6cbcc85 authored by vthumbe1503, committed by GitHub

[Pytorch] Integrate GPT OSS Swiglu in TransformerLayer (#2312)



* changes working
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support for onnx, minor comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* greptile review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/transformer.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/transformer.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert the name change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
parent a8e4346e
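For reviewers unfamiliar with the GPT-OSS variant: the new 'clamped_swiglu' option clamps both halves of the FC1 output before applying a scaled-sigmoid gate and multiplying by the shifted linear half. The sketch below is a plain-PyTorch restatement of the reference helper this commit adds for the ONNX-export path further down; the limit=7.0 and alpha=1.702 defaults are the values used in this diff's tests, not a claim about kernel-side defaults.

```python
import torch


def clamped_swiglu_reference(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    """Plain-PyTorch restatement of the clamped SwiGLU used in the ONNX-export helper below."""
    # FC1 produces both halves concatenated along the last dimension.
    x_glu, x_linear = x.chunk(2, dim=-1)
    # The gated half is capped from above only; the linear half is clamped on both sides.
    x_glu = x_glu.clamp(max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    # Scaled-sigmoid gate, then multiply by the shifted linear half.
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)


# Quick shape check on random data (shapes are illustrative).
out = clamped_swiglu_reference(torch.randn(4, 8))
assert out.shape == (4, 4)
```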
@@ -68,7 +68,7 @@ if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -122,6 +122,7 @@ all_activations = [
"sreglu",
"silu",
"swiglu",
"clamped_swiglu",
]
all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -547,7 +548,7 @@ def test_sanity_layernorm_mlp(
sigma = 0.023
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
block = LayerNormMLP(
config.hidden_size,
4 * config.hidden_size,
@@ -555,6 +556,7 @@
output_layer_init_method=output_layer_init_method,
zero_centered_gamma=zero_centered_gamma,
activation=activation,
activation_params=activation_params,
normalization=normalization,
params_dtype=dtype,
device="cuda",
@@ -99,6 +99,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
}
if recipe.delayed() or recipe.mxfp8():
# Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
@@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
}
# no activation fusion written yet
# Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
@@ -135,6 +137,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
"sreglu": (tex.sreglu, tex.dsreglu, None),
"silu": (tex.silu, tex.dsilu, None),
"swiglu": (tex.swiglu, tex.dswiglu, None),
"clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
}
raise NotImplementedError(f"Unhandled recipe type {recipe}")
@@ -199,6 +202,7 @@ class _LayerNormMLP(torch.autograd.Function):
bwd_ln_sm_margin: int,
zero_centered_gamma: bool,
activation: str,
activation_params: Optional[dict],
normalization: str,
ub_overlap_ag: bool,
ub_overlap_rs: bool,
@@ -440,6 +444,7 @@ class _LayerNormMLP(torch.autograd.Function):
# ACTIVATION - sometimes activation is fused with the GEMM above.
fc1_out_without_bias = None
act_params = activation_params or {}
if bias_gelu_fusion:
fc1_out = None
@@ -449,7 +454,7 @@
act_out, _, fc1_out, _ = fc1_outputs
elif debug:
fc1_out, *_ = fc1_outputs
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
act_out = fc2_input_quantizer(act_out)
else:
fc1_out, *_ = fc1_outputs
@@ -457,19 +462,19 @@
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_block_scaling():
# tex.quantize does not support GELU fusion for blockwise
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
act_out = tex.quantize(act_out, fc2_input_quantizer)
elif recipe.custom():
# tex.quantize does not support custom quantizers
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
act_out = fc2_input_quantizer(act_out)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
else:
if fp8_calibration:
act_out = activation_func(fc1_out, None)
act_out = activation_func(fc1_out, None, **act_params)
else:
act_out = activation_func(fc1_out, fc2_input_quantizer)
act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
if not is_grad_enabled:
clear_tensor_data(fc1_out)
@@ -624,6 +629,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.device = device
ctx.activation_dtype = activation_dtype
ctx.activation = activation
ctx.activation_params = activation_params
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
@@ -1002,6 +1008,7 @@ class _LayerNormMLP(torch.autograd.Function):
# --------------------------------------------------
# bias computation
act_params = ctx.activation_params or {}
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
if ctx.fc1_grad_output_quantizer is not None:
@@ -1015,7 +1022,7 @@
dact = ctx.fc1_grad_output_quantizer(dact)
elif ctx.debug:
dact_func = _act_func(ctx.activation)[1]
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)
dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params)
fc1_bias_grad = dact.sum(dim=0)
dact = ctx.fc1_grad_output_quantizer(dact)
elif (
@@ -1027,7 +1034,10 @@
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[2]
fc1_bias_grad, dact = dbias_dact_quantize_func(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer
fc2_dgrad,
fc1_out.to(ctx.activation_dtype),
ctx.fc1_grad_output_quantizer,
**act_params,
) # quantize bgrad gelu fused
else:
# Fusion: gemm + gelu,
@@ -1036,7 +1046,7 @@
ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
)[1]
dact = activation_func_bwd(
fc2_dgrad, fc1_out.to(ctx.activation_dtype), None
fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params
) # activation in high precision
if ctx.fp8:
@@ -1401,6 +1411,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, # bwd_ln_sm_margin
None, # zero_centered_gamma
None, # activation
None, # activation_params
None, # normalization
None, # ub_overlap_ag
None, # ub_overlap_rs
@@ -1436,7 +1447,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
'silu', 'swiglu', and 'clamped_swiglu'.
activation_params : Optional[dict], default = `None`
Additional parameters for the activation function.
Currently only used by the 'clamped_swiglu' activation, which
accepts the 'limit' and 'alpha' parameters.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -1537,6 +1552,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
bias: bool = True,
normalization: str = "LayerNorm",
activation: str = "gelu",
activation_params: Optional[dict] = None,
output_layer_init_method: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
params_dtype: Optional[torch.dtype] = None,
@@ -1564,6 +1580,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
self.use_bias = bias
self.activation = activation
self.activation_params = activation_params
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
@@ -1643,7 +1660,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.layer_norm_bias = None
# FC1 init
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
fc1_output_features = 2 * self.size_per_partition
else:
fc1_output_features = self.size_per_partition
@@ -1897,6 +1914,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
@@ -2026,6 +2044,19 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32
act_params = self.activation_params or {}
# Default params for clamped_swiglu in Transformer Engine
clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get(
"alpha", 1.702
)
def _clamped_swiglu(x, limit, alpha):
x_glu, x_linear = x.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y = out_glu * (x_linear + 1)
return y
activation_map = {
"gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
@@ -2040,6 +2071,9 @@
* x.chunk(2, -1)[1],
"silu": torch.nn.functional.silu,
"swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
"clamped_swiglu": lambda x: _clamped_swiglu(
x, clamped_swiglu_limit, clamped_swiglu_alpha
),
}
if self.activation not in activation_map:
raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
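To make the new LayerNormMLP arguments concrete, here is a hedged usage sketch patterned after the sanity test above. The sizes, dtype, and the limit/alpha values are illustrative choices rather than requirements, and a CUDA device is assumed.

```python
import torch
import transformer_engine.pytorch as te

# Illustrative sizes only; any hidden/FFN sizes supported by LayerNormMLP work the same way.
hidden_size, ffn_hidden_size = 1024, 4096

mlp = te.LayerNormMLP(
    hidden_size,
    ffn_hidden_size,
    activation="clamped_swiglu",
    # Same limit/alpha values the sanity test in this PR uses.
    activation_params={"limit": 7.0, "alpha": 1.702},
    params_dtype=torch.bfloat16,
    device="cuda",
)

# Sequence-first input: (seq_len, batch, hidden).
x = torch.randn(128, 2, hidden_size, dtype=torch.bfloat16, device="cuda")
y = mlp(x)
```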
@@ -176,7 +176,12 @@ class TransformerLayer(torch.nn.Module):
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
'silu', 'swiglu', and 'clamped_swiglu'.
activation_params : Optional[dict], default = `None`
Additional parameters for the activation function.
Currently only used by the 'clamped_swiglu' activation, which
accepts the 'limit' and 'alpha' parameters, e.g.
`activation_params={'limit': 7.0, 'alpha': 1.702}`.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
@@ -310,6 +315,7 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_wgrad: bool = True,
bias: bool = True,
activation: str = "gelu",
activation_params: Optional[dict] = None,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd",
@@ -475,6 +481,7 @@ class TransformerLayer(torch.nn.Module):
ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag,
activation=activation,
activation_params=activation_params,
normalization=normalization,
device=device,
name=name + ".layernorm_mlp" if name is not None else None,
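Finally, the end-to-end view from the commit title: a hedged sketch of constructing a TransformerLayer with the new activation. The model dimensions are placeholders, the activation_params values simply echo the docstring example, and a CUDA device is assumed.

```python
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation="clamped_swiglu",
    activation_params={"limit": 7.0, "alpha": 1.702},  # values from the docstring example above
    params_dtype=torch.bfloat16,
    device="cuda",
)

# Default attn_input_format is "sbhd": (seq_len, batch, hidden).
x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)
```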