Unverified Commit 82dde778 authored by ngoyal2707, committed by GitHub

make bias configurable (#130)



* made bias configurable
Signed-off-by: Naman Goyal <naman@fb.com>

* removed commented lines
Signed-off-by: Naman Goyal <naman@fb.com>

* Update transformer_engine/pytorch/jit.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: ngoyal2707 <ngoyal2707@users.noreply.github.com>

* Update transformer_engine/pytorch/jit.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: ngoyal2707 <ngoyal2707@users.noreply.github.com>

* fixed incorrect call to fused bias dropout add kernel
Signed-off-by: Naman Goyal <naman@fb.com>

* Update transformer_engine/pytorch/jit.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Separate FC1 and FC2 use_bias args; solves all ci errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* jit fusion improvement
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Naman Goyal <naman@fb.com>
Signed-off-by: ngoyal2707 <ngoyal2707@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Naman Goyal <naman@fb.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5d937c57
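The net effect of the change below: `bias` becomes a constructor argument of `TransformerLayer` and is threaded through `MultiHeadAttention`, the QKV/projection linears, and `LayerNormMLP`, so all additive biases in the layer can be switched off. A minimal usage sketch, assuming the post-merge API; the sizes, dtype, and input layout here are illustrative assumptions, not part of the diff:

import torch
import transformer_engine.pytorch as te

# Hypothetical sizes; with this change, bias=False disables every additive bias
# in the layer (QKV and projection linears, FC1 and FC2 of the MLP).
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    bias=False,
).to(dtype=torch.bfloat16).cuda()

# [sequence, batch, hidden] layout assumed here.
x = torch.randn(128, 2, 1024, dtype=torch.bfloat16, device="cuda")
y = layer(x)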
@@ -105,7 +105,6 @@ for s in args:
     if s.startswith("--framework="):
         framework = s.replace("--framework=", "")
         sys.argv.remove(s)
 if framework not in supported_frameworks.keys():
     raise ValueError("Unsupported framework " + framework)
@@ -351,7 +351,8 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
 @pytest.mark.parametrize("model", model_configs.keys())
 @pytest.mark.parametrize("skip_wgrad", all_boolean)
 @pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+@pytest.mark.parametrize("bias", all_boolean)
+def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias):
     if fp8_recipe is not None and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -375,6 +376,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
             apply_residual_connection_post_layernorm=False,
             output_layernorm=False,
             zero_centered_gamma=zero_centered_gamma,
+            bias=bias,
         )
         .to(dtype=dtype)
         .cuda()
@@ -97,7 +97,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype
 // ////////////////////////////////////////////////////////////////////////////////////////////////////
 inline size_t product(const std::vector<size_t> &shape) {
-    return std::reduce(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
+    return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
 }
 }  // namespace rmsnorm
@@ -3,7 +3,7 @@
 # See LICENSE for license information.

 """NVFuser functions and JIT utilities"""
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple

 import torch
@@ -36,6 +36,15 @@ def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


+@torch.jit.script
+def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
+    """
+    Fused GeLU; a copy of bias_gelu_fused_ without the bias, because JIT fusion
+    does not allow conditionals inside the scripted function.
+    """
+    x = inp
+    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
+
+
 # gradient of tanh approximation of gelu
 # gradient of actual gelu is:
 # 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -55,18 +64,39 @@ def bgrad_dgelu_fused_(
     return bgrad, dgelu


+@torch.jit.script
+def dgelu_fused_(
+    grad_output: torch.Tensor, inp: torch.Tensor
+) -> torch.Tensor:
+    """
+    Fused GeLU gradient; a copy of bgrad_dgelu_fused_ without the bias gradient,
+    because JIT fusion does not allow conditionals inside the scripted function.
+    """
+    x = inp
+    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
+    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
+    ff = 0.5 * x * (
+        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
+    ) + 0.5 * (1 + tanh_out)
+    dgelu = ff * grad_output
+    return dgelu
+
+
 def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
     """Disable native AMP for bias_gelu_fused_"""
     with torch.cuda.amp.autocast(enabled=False):
-        return bias_gelu_fused_(inp, bias)
+        if bias.numel() != 0:
+            return bias_gelu_fused_(inp, bias)
+        return gelu_fused_(inp)


 def bgrad_dgelu_fused(
     grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
     """Disable native AMP for `bgrad_dgelu_fused_`"""
     with torch.cuda.amp.autocast(enabled=False):
-        return bgrad_dgelu_fused_(grad_output, inp, bias)
+        if bias.numel() != 0:
+            return bgrad_dgelu_fused_(grad_output, inp, bias)
+        return None, dgelu_fused_(grad_output, inp)


 def bias_dropout_add(
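The convention introduced above is that a zero-element bias tensor means "no bias": the un-scripted wrappers check `bias.numel()` and dispatch to the bias-free fused kernels, returning `None` for the bias gradient. A small illustrative sketch of both paths (the CUDA device and the shapes are assumptions, not part of the diff):

import torch
from transformer_engine.pytorch.jit import bias_gelu_fused, bgrad_dgelu_fused

inp = torch.randn(32, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")
no_bias = torch.Tensor().cuda()  # numel() == 0 signals "bias disabled"

out_with_bias = bias_gelu_fused(inp, bias)   # fused bias + tanh-approx GELU
out_no_bias = bias_gelu_fused(inp, no_bias)  # falls back to gelu_fused_

grad = torch.ones_like(out_no_bias)
bgrad, dgelu = bgrad_dgelu_fused(grad, inp, bias)           # bias gradient computed
none_bgrad, dgelu2 = bgrad_dgelu_fused(grad, inp, no_bias)  # bgrad is None here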
@@ -2026,11 +2026,12 @@ class _LayerNormMLP(torch.autograd.Function):
         fc1_weight_fp8: Union[torch.Tensor, None],
         fc1_weight_t_fp8: Union[torch.Tensor, None],
         fc1_bias: torch.Tensor,
+        use_fc1_bias: bool,
         fc2_weight: torch.Tensor,
         fc2_weight_fp8: Union[torch.Tensor, None],
         fc2_weight_t_fp8: Union[torch.Tensor, None],
         fc2_bias: torch.Tensor,
-        use_bias: bool,
+        use_fc2_bias: bool,
         eps: float,
         is_first_microbatch: Union[bool, None],
         fp8: bool,
@@ -2126,8 +2127,8 @@
                 if activation_dtype == torch.float32
                 else activation_dtype
             )
-            fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
-            fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias
+            fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
+            fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias

             if update_fp8_weights:
                 if is_grad_enabled:
@@ -2176,7 +2177,7 @@
                 activation_dtype,
                 get_workspace(),
                 bias=fc1_bias,
-                use_bias=True,
+                use_bias=use_fc1_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
             )
@@ -2199,16 +2200,18 @@
                 activation_dtype,
                 get_workspace(),
                 bias=fc2_bias,
-                use_bias=use_bias,
+                use_bias=use_fc2_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
             )
         else:
             # Cast for native AMP
             fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
             fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
-            fc1_bias = cast_if_needed(fc1_bias, activation_dtype)
+            fc1_bias = (
+                cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias
+            )
             fc2_bias = (
-                cast_if_needed(fc2_bias, activation_dtype) if use_bias else fc2_bias
+                cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias
             )

             if fp8_calibration:
@@ -2225,12 +2228,13 @@
                 activation_dtype,
                 get_workspace(),
                 bias=fc1_bias,
-                use_bias=not bias_gelu_nvfusion,
+                use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
                 gelu=not bias_gelu_nvfusion,
             )
             if bias_gelu_nvfusion:
                 fc1_out, _, _ = fc1_outputs
                 gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
             else:
                 gelu_out, _, fc1_out = fc1_outputs
@@ -2249,7 +2253,7 @@
                 activation_dtype,
                 get_workspace(),
                 bias=fc2_bias,
-                use_bias=use_bias,
+                use_bias=use_fc2_bias,
             )

         if is_grad_enabled:
             ctx.save_for_backward(
@@ -2272,7 +2276,8 @@
             ctx.fp8_meta = fp8_meta
             ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
             ctx.is_first_microbatch = is_first_microbatch
-            ctx.use_bias = use_bias
+            ctx.use_fc1_bias = use_fc1_bias
+            ctx.use_fc2_bias = use_fc2_bias
             ctx.sequence_parallel = sequence_parallel
             ctx.tensor_parallel = tensor_parallel
             ctx.inp_shape = inp.shape
@@ -2319,6 +2324,7 @@
             fwd_scale_inverses,
         ) = ctx.saved_tensors

+        ctx.use_bias = ctx.use_fc2_bias  # For grad_output_preprocess
         (
             grad_output,
             grad_output_c,
@@ -2412,7 +2418,7 @@
                     get_workspace(),
                     layout="NT",
                     grad=True,
-                    use_bias=ctx.use_bias,
+                    use_bias=False,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=fc2_weight.main_grad
                     if ctx.fuse_wgrad_accumulation
@@ -2467,7 +2473,7 @@
                     get_workspace(),
                     layout="NT",
                     grad=True,
-                    use_bias=ctx.use_bias,
+                    use_bias=ctx.use_fc2_bias,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 )
@@ -2573,9 +2579,6 @@
                 ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
             )

-        if not ctx.use_bias:
-            fc2_bias_grad = None

         return (
             dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
             dgamma,
@@ -2583,11 +2586,12 @@
             fc1_wgrad if fc1_weight.requires_grad else None,
             None,
             None,
-            fc1_bias_grad,
+            fc1_bias_grad if ctx.use_fc1_bias else None,
+            None,
             fc2_wgrad if fc2_weight.requires_grad else None,
             None,
             None,
-            fc2_bias_grad,
+            fc2_bias_grad if ctx.use_fc2_bias else None,
             None,
             None,
             None,
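For readers unfamiliar with the convention in the hunk above: `torch.autograd.Function.backward` must return one entry per `forward` argument, so the new `use_fc1_bias`/`use_fc2_bias` flags each get a `None` slot, and the bias gradients themselves become `None` when the corresponding bias is disabled. A toy analogue of that pattern (the names here are made up for illustration, this is not the TE code):

import torch

class _AddOptionalBias(torch.autograd.Function):
    """Toy analogue: one gradient (or None) must be returned per forward input."""

    @staticmethod
    def forward(ctx, inp, bias, use_bias):
        ctx.use_bias = use_bias
        return inp + bias if use_bias else inp

    @staticmethod
    def backward(ctx, grad_out):
        bias_grad = grad_out.sum(dim=0) if ctx.use_bias else None
        # Gradients for: inp, bias, use_bias (a bool flag, hence always None).
        return grad_out, bias_grad, None

x = torch.randn(4, 8, requires_grad=True)
b = torch.zeros(8, requires_grad=True)
_AddOptionalBias.apply(x, b, True).sum().backward()   # b.grad is populated
_AddOptionalBias.apply(x, b, False).sum().backward()  # bias grad slot stays None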
@@ -2623,7 +2627,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
     eps : float, default = 1e-5
          a value added to the denominator of layer normalization for numerical stability.
     bias : bool, default = `True`
-          if set to `False`, the FC2 layer will not learn an additive bias.
+          if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
     init_method : Callable, default = `None`
                   used for initializing FC1 weights in the following way: `init_method(weight)`.
                   When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -2666,7 +2670,7 @@
                               if set to `True`, enables fusing of creation and accumulation of
                               the weight gradient.
     return_bias : bool, default = `False`
-                 when set to `True`, this module will not apply the additive bias itself, but
+                 when set to `True`, this module will not apply the additive bias for FC2, but
                  instead return the bias value during the forward pass together with the
                  output of the linear transformation :math:`y = xA^T`. This is useful when
                  the bias addition can be fused to subsequent operations.
@@ -2772,14 +2776,17 @@
                 stride=1,
             )

-        self.fc1_bias = Parameter(
-            torch.empty(
-                self.size_per_partition,
-                device=torch.cuda.current_device(),
-                dtype=params_dtype,
-            )
-        )
-        set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
+        if self.use_bias:
+            self.fc1_bias = Parameter(
+                torch.empty(
+                    self.size_per_partition,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype,
+                )
+            )
+            set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
+        else:
+            self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False)

         with torch.no_grad():
             self.fc1_bias.zero_()
@@ -2884,6 +2891,7 @@
             self.weight1_fp8 if self.fp8 else None,
             self.weight1_t_fp8 if self.fp8 else None,
             self.fc1_bias,
+            self.use_bias,
             self.fc2_weight,
             self.weight2_fp8 if self.fp8 else None,
             self.weight2_t_fp8 if self.fp8 else None,
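When `bias` is disabled, `fc1_bias` is registered as a zero-element, non-persistent buffer rather than a `Parameter`, so downstream code can still pass the attribute around and use `numel() != 0` as the "bias present" check without touching the state dict. A self-contained sketch of that pattern (a stub module for illustration, not the TE implementation):

import torch

class MLPStub(torch.nn.Module):
    """Minimal illustration of the zero-element-buffer-as-disabled-bias pattern."""

    def __init__(self, hidden: int, use_bias: bool, params_dtype=torch.float32):
        super().__init__()
        self.fc1_weight = torch.nn.Parameter(torch.randn(hidden, hidden, dtype=params_dtype))
        if use_bias:
            self.fc1_bias = torch.nn.Parameter(torch.zeros(hidden, dtype=params_dtype))
        else:
            # Empty buffer: numel() == 0 means "no bias"; persistent=False keeps it
            # out of the state dict.
            self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = x @ self.fc1_weight.t()
        if self.fc1_bias.numel() != 0:
            out = out + self.fc1_bias
        return out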
@@ -494,6 +494,7 @@ class MultiHeadAttention(torch.nn.Module):
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
+        bias: bool = True,
     ) -> None:
         super().__init__()
         self.layer_number = (layer_number,)
@@ -539,7 +540,7 @@
                 3 * hidden_size,
                 eps=layernorm_epsilon,
                 init_method=init_method,
-                bias=True,
+                bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 return_layernorm_output=return_layernorm_output,
@@ -552,7 +553,7 @@
                 hidden_size,
                 3 * hidden_size,
                 init_method=init_method,
-                bias=True,
+                bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
@@ -565,7 +566,7 @@
                 hidden_size,
                 eps=layernorm_epsilon,
                 init_method=init_method,
-                bias=True,
+                bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 return_layernorm_output=return_layernorm_output,
@@ -577,7 +578,7 @@
                 hidden_size,
                 hidden_size,
                 init_method=init_method,
-                bias=True,
+                bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 **common_gemm_kwargs,
@@ -586,7 +587,7 @@
                 hidden_size,
                 2 * hidden_size,
                 init_method=init_method,
-                bias=True,
+                bias=bias,
                 return_bias=False,
                 parallel_mode=qkv_parallel_mode,
                 parameters_split=("key_", "value_") if not fuse_qkv_params else None,
@@ -611,7 +612,7 @@
             hidden_size,
             hidden_size,
             init_method=output_layer_init_method,
-            bias=True,
+            bias=bias,
             return_bias=True,
             parallel_mode="row" if set_parallel_mode else None,
             **common_gemm_kwargs,
@@ -890,6 +891,9 @@ class TransformerLayer(torch.nn.Module):
                              interpretation is that the individual `q`, `k`, and `v` weights for each
                              attention head are interleaved. This parameter is set to `False` when
                              using :attr:`fuse_qkv_params=False`.
+    bias : bool, default = `True`
+          if set to `False`, the transformer layer will not learn any additive biases.
+
     Parallelism parameters
     ----------------------
     set_parallel_mode : bool, default = `False`
@@ -965,6 +969,7 @@
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
+        bias: bool = True,
     ) -> None:
         super().__init__()
@@ -1039,6 +1044,7 @@
             attn_mask_type=self_attn_mask_type,
             input_layernorm=not output_layernorm,
             attention_type="self",
+            bias=bias,
         )

         if layer_type == "decoder":
@@ -1048,6 +1054,7 @@
                 attn_mask_type="padding",
                 input_layernorm=True,
                 attention_type="cross",
+                bias=bias,
             )

         # LayerNorm -> gelu(Linear + Bias) -> Linear
@@ -1063,7 +1070,7 @@
             get_rng_state_tracker=get_rng_state_tracker,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
-            bias=True,
+            bias=bias,
             return_bias=True,
             sequence_parallel=self.sequence_parallel,
             params_dtype=params_dtype,
@@ -1184,6 +1191,7 @@
             is_first_microbatch=is_first_microbatch,
             checkpoint_core_attention=checkpoint_core_attention,
         )
+
         if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
             attention_output, attention_bias, residual = self_attention_outputs
         else:
@@ -1200,18 +1208,22 @@
         bias_dropout_add_func = get_bias_dropout_add(self.training)

         # Bias dropoout add.
-        if self.drop_path is None:
+        if self.drop_path is None and attention_bias.numel() != 0:
             with self.bias_dropout_add_exec_handler():
                 bda_output = bias_dropout_add_func(
                     attention_output, attention_bias, residual, self.hidden_dropout
                 )
         else:
+            if attention_bias.numel() != 0:
+                attention_output = attention_output + attention_bias
             out = torch.nn.functional.dropout(
-                attention_output + attention_bias,
+                attention_output,
                 p=self.hidden_dropout,
                 training=self.training,
             )
-            bda_output = residual + self.drop_path(out)
+            if self.drop_path is not None:
+                out = self.drop_path(out)
+            bda_output = residual + out

         # Cross attention.
         if self.layer_type == "decoder":
@@ -1228,11 +1240,18 @@
             attention_output, attention_bias = inter_attention_outputs
             residual = bda_output

-            with self.bias_dropout_add_exec_handler():
-                bda_output = bias_dropout_add_func(
-                    attention_output, attention_bias, residual, self.hidden_dropout
-                )
+            if attention_bias.numel() != 0:
+                with self.bias_dropout_add_exec_handler():
+                    bda_output = bias_dropout_add_func(
+                        attention_output, attention_bias, residual, self.hidden_dropout
+                    )
+            else:
+                out = torch.nn.functional.dropout(
+                    attention_output,
+                    p=self.hidden_dropout,
+                    training=self.training,
+                )
+                bda_output = residual + out

         # MLP.
         mlp_outputs = self.layernorm_mlp(
             bda_output, is_first_microbatch=is_first_microbatch
@@ -1244,16 +1263,20 @@
             residual = bda_output

         # Bias dropoout add.
-        if self.drop_path is None:
+        if self.drop_path is None and mlp_bias.numel() != 0:
             with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
                     mlp_output, mlp_bias, residual, self.hidden_dropout
                 )
         else:
+            if mlp_bias.numel() != 0:
+                mlp_output = mlp_output + mlp_bias
             out = torch.nn.functional.dropout(
-                mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training
+                mlp_output, p=self.hidden_dropout, training=self.training
             )
-            output = residual + self.drop_path(out)
+            if self.drop_path is not None:
+                out = self.drop_path(out)
+            output = residual + out

         # For BERT like architectures.
         if self.output_layernorm:
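The control flow added above is the same in all three residual blocks: the fused bias-dropout-add kernel is used only when `drop_path` is unused and a non-empty bias is available; otherwise the bias (if any) is added explicitly, followed by plain dropout, optional drop-path, and the residual add. A simplified mirror of that branching, with hypothetical argument names:

import torch

def residual_block(output, bias, residual, hidden_dropout, training,
                   bias_dropout_add_func, drop_path=None):
    """Simplified mirror of the branching used for the attention and MLP outputs."""
    if drop_path is None and bias.numel() != 0:
        # Fused bias + dropout + residual add.
        return bias_dropout_add_func(output, bias, residual, hidden_dropout)
    if bias.numel() != 0:
        output = output + bias
    out = torch.nn.functional.dropout(output, p=hidden_dropout, training=training)
    if drop_path is not None:
        out = drop_path(out)
    return residual + out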