"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "a41bf7118b1ee71b72bbf9fb040183a0ef1665a9"
Unverified commit c6cbcc85, authored by vthumbe1503 and committed by GitHub

[Pytorch] Integrate GPT OSS Swiglu in TransformerLayer (#2312)



* changes working
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add support for onnx, minor comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* greptile review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/transformer.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/module/layernorm_mlp.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update transformer_engine/pytorch/transformer.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert the name change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
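
For context, a minimal usage sketch of the new option based on the signatures added in this diff; the hidden sizes, dtype, and input shape below are illustrative and not taken from the PR:

```python
# Minimal sketch (not from the PR): exercises the new activation/activation_params
# arguments added to TransformerLayer and forwarded to LayerNormMLP.
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,        # illustrative size
    ffn_hidden_size=4096,    # illustrative size
    num_attention_heads=16,  # illustrative size
    activation="clamped_swiglu",
    # Same limit/alpha values that the new sanity test and ONNX fallback default to.
    activation_params={"limit": 7.0, "alpha": 1.702},
    params_dtype=torch.bfloat16,
    device="cuda",
)

# Default attn_input_format is "sbhd": [sequence, batch, hidden].
x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)
```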
parent a8e4346e
@@ -68,7 +68,7 @@ if fp8_available:
     fp8_recipes.append(recipe.Float8CurrentScaling())
 fp8_recipes.append(None)
-supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"]
 all_normalizations = ["LayerNorm", "RMSNorm"]
...
@@ -122,6 +122,7 @@ all_activations = [
     "sreglu",
     "silu",
     "swiglu",
+    "clamped_swiglu",
 ]
 all_normalizations = ["LayerNorm", "RMSNorm"]
@@ -547,7 +548,7 @@ def test_sanity_layernorm_mlp(
     sigma = 0.023
     init_method = init_method_normal(sigma)
     output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
+    activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702}
     block = LayerNormMLP(
         config.hidden_size,
         4 * config.hidden_size,
@@ -555,6 +556,7 @@ def test_sanity_layernorm_mlp(
         output_layer_init_method=output_layer_init_method,
         zero_centered_gamma=zero_centered_gamma,
         activation=activation,
+        activation_params=activation_params,
         normalization=normalization,
         params_dtype=dtype,
         device="cuda",
...
@@ -99,6 +99,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
             "sreglu": (tex.sreglu, tex.dsreglu, None),
             "silu": (tex.silu, tex.dsilu, None),
             "swiglu": (tex.swiglu, tex.dswiglu, None),
+            "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
         }
     if recipe.delayed() or recipe.mxfp8():
         # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu]
@@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
             "sreglu": (tex.sreglu, tex.dsreglu, None),
             "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu),
             "swiglu": (tex.swiglu, tex.dswiglu, None),
+            "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
         }
     # no activation fusion written yet
     # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: []
@@ -135,6 +137,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None):
             "sreglu": (tex.sreglu, tex.dsreglu, None),
             "silu": (tex.silu, tex.dsilu, None),
             "swiglu": (tex.swiglu, tex.dswiglu, None),
+            "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None),
         }
     raise NotImplementedError(f"Unhandled recipe type {recipe}")
@@ -199,6 +202,7 @@ class _LayerNormMLP(torch.autograd.Function):
         bwd_ln_sm_margin: int,
         zero_centered_gamma: bool,
         activation: str,
+        activation_params: Optional[dict],
         normalization: str,
         ub_overlap_ag: bool,
         ub_overlap_rs: bool,
@@ -440,6 +444,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # ACTIVATION - sometimes activation is fused with the GEMM above.
         fc1_out_without_bias = None
+        act_params = activation_params or {}
         if bias_gelu_fusion:
             fc1_out = None
@@ -449,7 +454,7 @@ class _LayerNormMLP(torch.autograd.Function):
             act_out, _, fc1_out, _ = fc1_outputs
         elif debug:
             fc1_out, *_ = fc1_outputs
-            act_out = activation_func(fc1_out, None)
+            act_out = activation_func(fc1_out, None, **act_params)
             act_out = fc2_input_quantizer(act_out)
         else:
             fc1_out, *_ = fc1_outputs
@@ -457,19 +462,19 @@ class _LayerNormMLP(torch.autograd.Function):
                 recipe = FP8GlobalStateManager.get_fp8_recipe()
                 if recipe.float8_block_scaling():
                     # tex.quantize does not support GELU fusion for blockwise
-                    act_out = activation_func(fc1_out, None)
+                    act_out = activation_func(fc1_out, None, **act_params)
                     act_out = tex.quantize(act_out, fc2_input_quantizer)
                 elif recipe.custom():
                     # tex.quantize does not support custom quantizers
-                    act_out = activation_func(fc1_out, None)
+                    act_out = activation_func(fc1_out, None, **act_params)
                     act_out = fc2_input_quantizer(act_out)
                 else:
-                    act_out = activation_func(fc1_out, fc2_input_quantizer)
+                    act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
             else:
                 if fp8_calibration:
-                    act_out = activation_func(fc1_out, None)
+                    act_out = activation_func(fc1_out, None, **act_params)
                 else:
-                    act_out = activation_func(fc1_out, fc2_input_quantizer)
+                    act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params)
         if not is_grad_enabled:
             clear_tensor_data(fc1_out)
@@ -624,6 +629,7 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.device = device
             ctx.activation_dtype = activation_dtype
             ctx.activation = activation
+            ctx.activation_params = activation_params
             ctx.fp8 = fp8
             ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
@@ -1002,6 +1008,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # --------------------------------------------------
         # bias computation
+        act_params = ctx.activation_params or {}
         fc1_bias_grad = None
         fuse_gemm_and_bias_fc1_wgrad = False
         if ctx.fc1_grad_output_quantizer is not None:
@@ -1015,7 +1022,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 dact = ctx.fc1_grad_output_quantizer(dact)
             elif ctx.debug:
                 dact_func = _act_func(ctx.activation)[1]
-                dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None)
+                dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params)
                 fc1_bias_grad = dact.sum(dim=0)
                 dact = ctx.fc1_grad_output_quantizer(dact)
             elif (
@@ -1027,7 +1034,10 @@ class _LayerNormMLP(torch.autograd.Function):
                     ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
                 )[2]
                 fc1_bias_grad, dact = dbias_dact_quantize_func(
-                    fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer
+                    fc2_dgrad,
+                    fc1_out.to(ctx.activation_dtype),
+                    ctx.fc1_grad_output_quantizer,
+                    **act_params,
                 )  # quantize bgrad gelu fused
             else:
                 # Fusion: gemm + gelu,
@@ -1036,7 +1046,7 @@ class _LayerNormMLP(torch.autograd.Function):
                     ctx.activation, ctx.fp8_recipe if ctx.fp8 else None
                 )[1]
                 dact = activation_func_bwd(
-                    fc2_dgrad, fc1_out.to(ctx.activation_dtype), None
+                    fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params
                 )  # activation in high precision
             if ctx.fp8:
@@ -1401,6 +1411,7 @@ class _LayerNormMLP(torch.autograd.Function):
             None,  # bwd_ln_sm_margin
             None,  # zero_centered_gamma
             None,  # activation
+            None,  # activation_params
             None,  # normalization
             None,  # ub_overlap_ag
             None,  # ub_overlap_rs
@@ -1436,7 +1447,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
     activation : str, default = 'gelu'
                 activation function used.
                 Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
-                'silu', and 'swiglu'.
+                'silu', 'swiglu', and 'clamped_swiglu'.
+    activation_params : dict, default = `None`
+                       Additional parameters for the activation function.
+                       At the moment, only used for 'clamped_swiglu' activation which
+                       supports 'limit' and 'alpha' parameters.
     init_method : Callable, default = `None`
                  used for initializing FC1 weights in the following way: `init_method(weight)`.
                  When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -1537,6 +1552,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         bias: bool = True,
         normalization: str = "LayerNorm",
         activation: str = "gelu",
+        activation_params: Optional[dict] = None,
         output_layer_init_method: Optional[Callable] = None,
         fuse_wgrad_accumulation: bool = False,
         params_dtype: Optional[torch.dtype] = None,
@@ -1564,6 +1580,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!"
         self.use_bias = bias
         self.activation = activation
+        self.activation_params = activation_params
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
@@ -1643,7 +1660,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.layer_norm_bias = None
         # FC1 init
-        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]:
+        if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
@@ -1897,6 +1914,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.bwd_ln_sm_margin,
             self.zero_centered_gamma,
             self.activation,
+            self.activation_params,
             self.normalization,
             self.ub_overlap_ag,
             self.ub_overlap_rs,
@@ -2026,6 +2044,19 @@ class LayerNormMLP(TransformerEngineBaseModule):
         fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias)
         fc1_out = fc1_out.to(torch.float32)  # activation is computed in fp32
+        act_params = self.activation_params or {}
+        # Default params for clamped_swiglu in Transformer Engine
+        clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get(
+            "alpha", 1.702
+        )
+
+        def _clamped_swiglu(x, limit, alpha):
+            x_glu, x_linear = x.chunk(2, dim=-1)
+            x_glu = x_glu.clamp(min=None, max=limit)
+            x_linear = x_linear.clamp(min=-limit, max=limit)
+            out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+            y = out_glu * (x_linear + 1)
+            return y
+
         activation_map = {
             "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"),
@@ -2040,6 +2071,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
             * x.chunk(2, -1)[1],
             "silu": torch.nn.functional.silu,
             "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1],
+            "clamped_swiglu": lambda x: _clamped_swiglu(
+                x, clamped_swiglu_limit, clamped_swiglu_alpha
+            ),
         }
         if self.activation not in activation_map:
             raise ValueError(f"Unsupported activation in onnx export: {self.activation}")
...
@@ -176,7 +176,12 @@ class TransformerLayer(torch.nn.Module):
     activation : str, default = 'gelu'
                Type of activation used in MLP block.
                Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
-               'silu', and 'swiglu'.
+               'silu', 'swiglu', and 'clamped_swiglu'.
+    activation_params : Optional[dict], default = `None`
+               Additional parameters for the activation function.
+               At the moment, only used for 'clamped_swiglu' activation which
+               supports 'limit' and 'alpha' parameters. You can set these as
+               `activation_params={'limit': 7.0, 'alpha': 1.702}`.
     device : Union[torch.device, str], default = "cuda"
               The device on which the parameters of the model will be allocated. It is the user's
               responsibility to ensure all parameters are moved to the GPU before running the
@@ -310,6 +315,7 @@ class TransformerLayer(torch.nn.Module):
         ub_bulk_wgrad: bool = True,
         bias: bool = True,
         activation: str = "gelu",
+        activation_params: Optional[dict] = None,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
@@ -475,6 +481,7 @@ class TransformerLayer(torch.nn.Module):
             ub_overlap_rs=ub_overlap_rs,
             ub_overlap_ag=ub_overlap_ag,
             activation=activation,
+            activation_params=activation_params,
             normalization=normalization,
             device=device,
             name=name + ".layernorm_mlp" if name is not None else None,
...
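
For readers checking the math, here is a standalone restatement of the `_clamped_swiglu` helper added to the ONNX export path above, with the same 7.0/1.702 defaults used by the new sanity test; it is an assumption here that the fused `tex.clamped_swiglu`/`tex.clamped_dswiglu` kernels implement the same definition.

```python
# Standalone restatement of the ONNX-path helper above (assumption: the fused
# tex.clamped_swiglu kernel computes the same function).
import torch


def clamped_swiglu_reference(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    # For gated activations FC1 produces 2x features: a gate half and a linear half.
    x_glu, x_linear = x.chunk(2, dim=-1)
    x_glu = x_glu.clamp(max=limit)                    # gate clamped from above only
    x_linear = x_linear.clamp(min=-limit, max=limit)  # linear branch clamped on both sides
    # Differs from plain swiglu by the clamping, the alpha scale inside the
    # sigmoid, and the "+ 1" on the linear branch.
    return x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)
```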