Unverified commit 82dde778, authored by ngoyal2707 and committed by GitHub

make bias configurable (#130)



* made bias configurable
Signed-off-by: Naman Goyal <naman@fb.com>

* removed commented lines
Signed-off-by: Naman Goyal <naman@fb.com>

* Update transformer_engine/pytorch/jit.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: ngoyal2707 <ngoyal2707@users.noreply.github.com>

* Update transformer_engine/pytorch/jit.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: ngoyal2707 <ngoyal2707@users.noreply.github.com>

* fixed incorrect call to fused bias dropout add kernel
Signed-off-by: Naman Goyal <naman@fb.com>

* Update transformer_engine/pytorch/jit.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

* Separate FC1 and FC2 use_bias args; fixes all CI errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* jit fusion improvement
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Docs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Naman Goyal <naman@fb.com>
Signed-off-by: ngoyal2707 <ngoyal2707@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Naman Goyal <naman@fb.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5d937c57
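In short, this commit threads a single bias flag from TransformerLayer down through MultiHeadAttention, the LayerNormLinear/Linear sublayers, and LayerNormMLP, so every additive bias can be switched off. A minimal usage sketch (sizes and dtype are illustrative, and a CUDA build of Transformer Engine is assumed):

import torch
import transformer_engine.pytorch as te

# Build a layer with all additive biases disabled via the new flag.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    bias=False,  # new in this commit; previously these biases were hard-coded on
).cuda()

x = torch.randn(128, 4, 1024, device="cuda")  # (seq, batch, hidden)
y = layer(x)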
@@ -105,7 +105,6 @@ for s in args:
if s.startswith("--framework="):
framework = s.replace("--framework=", "")
sys.argv.remove(s)
if framework not in supported_frameworks.keys():
raise ValueError("Unsupported framework " + framework)
......
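For context, the hunk above sits in setup.py's ad-hoc argument handling. A hedged, self-contained sketch of the pattern (the framework table here is illustrative, not the file's actual contents):

import sys

supported_frameworks = {"all": None, "pytorch": None}  # illustrative keys only

framework = "all"
for s in list(sys.argv):  # iterate over a copy so removal below is safe
    if s.startswith("--framework="):
        framework = s.replace("--framework=", "")
        sys.argv.remove(s)
if framework not in supported_frameworks.keys():
    raise ValueError("Unsupported framework " + framework)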
@@ -351,7 +351,8 @@ def test_sanity_layernorm_mlp(dtype, bs, fp8_recipe, model, skip_wgrad, zero_cen
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
@pytest.mark.parametrize("bias", all_boolean)
def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias):
if fp8_recipe is not None and not fp8_available:
pytest.skip(reason_for_no_fp8)
@@ -375,6 +376,7 @@ def test_sanity_gpt(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_gamm
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
zero_centered_gamma=zero_centered_gamma,
bias=bias,
)
.to(dtype=dtype)
.cuda()
......
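The new bias parametrization doubles the generated test matrix. A hedged, reduced sketch of the mechanism (the test body here is illustrative only):

import pytest

all_boolean = [True, False]

@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("bias", all_boolean)
def test_sanity_gpt_reduced(zero_centered_gamma, bias):
    # The real test builds TransformerLayer(..., bias=bias) and runs a
    # forward/backward pass; each added boolean parametrize doubles the cases.
    assert isinstance(bias, bool) and isinstance(zero_centered_gamma, bool)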
@@ -97,7 +97,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype
// ////////////////////////////////////////////////////////////////////////////////////////////////////
inline size_t product(const std::vector<size_t> &shape) {
return std::reduce(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>());
}
} // namespace rmsnorm
......
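The swap from std::reduce to std::accumulate is presumably for portability: std::reduce requires C++17 library support that some older toolchains lack, while std::accumulate has been in <numeric> since C++98 and computes the same product here. For reference, a Python equivalent of the helper:

import math

def product(shape):
    # Multiply all dimensions together, e.g. product([2, 3, 4]) == 24.
    return math.prod(shape, start=1)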
@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""NVFuser functions and JIT utilities"""
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import torch
@@ -36,6 +36,15 @@ def bias_gelu_fused_(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
@torch.jit.script
def gelu_fused_(inp: torch.Tensor) -> torch.Tensor:
"""
GeLU fused, this is copy of bias_gelu_fused cause jit fusion doesn't allow conditioning.
"""
x = inp
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -55,18 +64,39 @@ def bgrad_dgelu_fused_(
return bgrad, dgelu
@torch.jit.script
def dgelu_fused_(
grad_output: torch.Tensor, inp: torch.Tensor
) -> torch.Tensor:
"""
Dgelu fused, this is copy of bgrad_dgelu_fused_ cause jit fusion doesn't allow conditioning.
"""
x = inp
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)
dgelu = ff * grad_output
return dgelu
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with torch.cuda.amp.autocast(enabled=False):
return bias_gelu_fused_(inp, bias)
if bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)
def bgrad_dgelu_fused(
grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with torch.cuda.amp.autocast(enabled=False):
return bgrad_dgelu_fused_(grad_output, inp, bias)
if bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp)
def bias_dropout_add(
......
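The wrappers above route to the bias-free kernels whenever an empty bias tensor (the numel() == 0 sentinel) is passed. A standalone, hedged check that the hand-derived tanh-GeLU gradient used in dgelu_fused_ matches autograd:

import torch

x = torch.randn(64, dtype=torch.float64, requires_grad=True)
# Tanh approximation of GeLU, same constants as bias_gelu_fused_.
y = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
y.backward(torch.ones_like(y))

tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * (
    (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)

# The analytic expression from dgelu_fused_ agrees with autograd.
torch.testing.assert_close(x.grad, ff.detach())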
@@ -2026,11 +2026,12 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_weight_fp8: Union[torch.Tensor, None],
fc1_weight_t_fp8: Union[torch.Tensor, None],
fc1_bias: torch.Tensor,
use_fc1_bias: bool,
fc2_weight: torch.Tensor,
fc2_weight_fp8: Union[torch.Tensor, None],
fc2_weight_t_fp8: Union[torch.Tensor, None],
fc2_bias: torch.Tensor,
use_bias: bool,
use_fc2_bias: bool,
eps: float,
is_first_microbatch: Union[bool, None],
fp8: bool,
@@ -2126,8 +2127,8 @@
if activation_dtype == torch.float32
else activation_dtype
)
fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias
fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias
if update_fp8_weights:
if is_grad_enabled:
@@ -2176,7 +2177,7 @@
activation_dtype,
get_workspace(),
bias=fc1_bias,
use_bias=True,
use_bias=use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
@@ -2199,16 +2200,18 @@
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_bias,
use_bias=use_fc2_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
else:
# Cast for native AMP
fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
fc1_bias = cast_if_needed(fc1_bias, activation_dtype)
fc1_bias = (
cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias
)
fc2_bias = (
cast_if_needed(fc2_bias, activation_dtype) if use_bias else fc2_bias
cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias
)
if fp8_calibration:
@@ -2225,12 +2228,13 @@
activation_dtype,
get_workspace(),
bias=fc1_bias,
use_bias=not bias_gelu_nvfusion,
use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
gelu=not bias_gelu_nvfusion,
)
if bias_gelu_nvfusion:
fc1_out, _, _ = fc1_outputs
gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
else:
gelu_out, _, fc1_out = fc1_outputs
@@ -2249,7 +2253,7 @@
activation_dtype,
get_workspace(),
bias=fc2_bias,
use_bias=use_bias,
use_bias=use_fc2_bias,
)
if is_grad_enabled:
ctx.save_for_backward(
@@ -2272,7 +2276,8 @@
ctx.fp8_meta = fp8_meta
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.use_fc1_bias = use_fc1_bias
ctx.use_fc2_bias = use_fc2_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
@@ -2319,6 +2324,7 @@
fwd_scale_inverses,
) = ctx.saved_tensors
ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess
(
grad_output,
grad_output_c,
@@ -2412,7 +2418,7 @@
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
use_bias=False,
accumulate=accumulate_wgrad_into_param_main_grad,
out=fc2_weight.main_grad
if ctx.fuse_wgrad_accumulation
@@ -2467,7 +2473,7 @@
get_workspace(),
layout="NT",
grad=True,
use_bias=ctx.use_bias,
use_bias=ctx.use_fc2_bias,
accumulate=accumulate_wgrad_into_param_main_grad,
out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
@@ -2573,9 +2579,6 @@
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if not ctx.use_bias:
fc2_bias_grad = None
return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
@@ -2583,11 +2586,12 @@
fc1_wgrad if fc1_weight.requires_grad else None,
None,
None,
fc1_bias_grad,
fc1_bias_grad if ctx.use_fc1_bias else None,
None,
fc2_wgrad if fc2_weight.requires_grad else None,
None,
None,
fc2_bias_grad,
fc2_bias_grad if ctx.use_fc2_bias else None,
None,
None,
None,
@@ -2623,7 +2627,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the FC2 layer will not learn an additive bias.
if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -2666,7 +2670,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
if set to `True`, enables fusing of creation and accumulation of
the weight gradient.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
when set to `True`, this module will not apply the additive bias for FC2, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
@@ -2772,14 +2776,17 @@ class LayerNormMLP(TransformerEngineBaseModule):
stride=1,
)
self.fc1_bias = Parameter(
torch.empty(
self.size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
if self.use_bias:
self.fc1_bias = Parameter(
torch.empty(
self.size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
)
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
else:
self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False)
with torch.no_grad():
self.fc1_bias.zero_()
@@ -2884,6 +2891,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
self.fc1_bias,
self.use_bias,
self.fc2_weight,
self.weight2_fp8 if self.fp8 else None,
self.weight2_t_fp8 if self.fp8 else None,
......
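A hedged usage sketch of the module-level change (sizes illustrative, CUDA build assumed): with bias=False, fc1_bias becomes an empty, non-persistent buffer rather than a Parameter, so the autograd function can still receive it and branch on use_fc1_bias/use_fc2_bias:

import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    bias=False,  # FC1 and FC2 both skip the additive bias
).cuda()

# Empty sentinel buffer, not a learnable Parameter; downstream code
# dispatches on numel() == 0.
assert mlp.fc1_bias.numel() == 0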
@@ -494,6 +494,7 @@ class MultiHeadAttention(torch.nn.Module):
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
bias: bool = True,
) -> None:
super().__init__()
self.layer_number = (layer_number,)
@@ -539,7 +540,7 @@
3 * hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=True,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
@@ -552,7 +553,7 @@
hidden_size,
3 * hidden_size,
init_method=init_method,
bias=True,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
@@ -565,7 +566,7 @@
hidden_size,
eps=layernorm_epsilon,
init_method=init_method,
bias=True,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
return_layernorm_output=return_layernorm_output,
@@ -577,7 +578,7 @@
hidden_size,
hidden_size,
init_method=init_method,
bias=True,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
**common_gemm_kwargs,
@@ -586,7 +587,7 @@
hidden_size,
2 * hidden_size,
init_method=init_method,
bias=True,
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
parameters_split=("key_", "value_") if not fuse_qkv_params else None,
@@ -611,7 +612,7 @@
hidden_size,
hidden_size,
init_method=output_layer_init_method,
bias=True,
bias=bias,
return_bias=True,
parallel_mode="row" if set_parallel_mode else None,
**common_gemm_kwargs,
@@ -890,6 +891,9 @@ class TransformerLayer(torch.nn.Module):
interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`.
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
Parallelism parameters
----------------------
set_parallel_mode : bool, default = `False`
@@ -965,6 +969,7 @@
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
bias: bool = True,
) -> None:
super().__init__()
@@ -1039,6 +1044,7 @@
attn_mask_type=self_attn_mask_type,
input_layernorm=not output_layernorm,
attention_type="self",
bias=bias,
)
if layer_type == "decoder":
@@ -1048,6 +1054,7 @@
attn_mask_type="padding",
input_layernorm=True,
attention_type="cross",
bias=bias,
)
# LayerNorm -> gelu(Linear + Bias) -> Linear
@@ -1063,7 +1070,7 @@
get_rng_state_tracker=get_rng_state_tracker,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias=True,
bias=bias,
return_bias=True,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
@@ -1184,6 +1191,7 @@
is_first_microbatch=is_first_microbatch,
checkpoint_core_attention=checkpoint_core_attention,
)
if self.apply_residual_connection_post_layernorm and not self.output_layernorm:
attention_output, attention_bias, residual = self_attention_outputs
else:
@@ -1200,18 +1208,22 @@
bias_dropout_add_func = get_bias_dropout_add(self.training)
# Bias dropout add.
if self.drop_path is None:
if self.drop_path is None and attention_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
if attention_bias.numel() != 0:
attention_output = attention_output + attention_bias
out = torch.nn.functional.dropout(
attention_output + attention_bias,
attention_output,
p=self.hidden_dropout,
training=self.training,
)
bda_output = residual + self.drop_path(out)
if self.drop_path is not None:
out = self.drop_path(out)
bda_output = residual + out
# Cross attention.
if self.layer_type == "decoder":
@@ -1228,11 +1240,18 @@
attention_output, attention_bias = inter_attention_outputs
residual = bda_output
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
if attention_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
bda_output = bias_dropout_add_func(
attention_output, attention_bias, residual, self.hidden_dropout
)
else:
out = torch.nn.functional.dropout(
attention_output,
p=self.hidden_dropout,
training=self.training,
)
bda_output = residual + out
# MLP.
mlp_outputs = self.layernorm_mlp(
bda_output, is_first_microbatch=is_first_microbatch
@@ -1244,16 +1263,20 @@
residual = bda_output
# Bias dropout add.
if self.drop_path is None:
if self.drop_path is None and mlp_bias.numel() != 0:
with self.bias_dropout_add_exec_handler():
output = bias_dropout_add_func(
mlp_output, mlp_bias, residual, self.hidden_dropout
)
else:
if mlp_bias.numel() != 0:
mlp_output = mlp_output + mlp_bias
out = torch.nn.functional.dropout(
mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training
mlp_output, p=self.hidden_dropout, training=self.training
)
output = residual + self.drop_path(out)
if self.drop_path is not None:
out = self.drop_path(out)
output = residual + out
# For BERT like architectures.
if self.output_layernorm:
......
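The repeated branching above can be summarized in one hypothetical helper (a distillation of the new control flow, not code from the file): the fused bias-dropout-add runs only when a bias is actually present and there is no drop path; otherwise the module falls back to plain dropout plus the residual:

import torch

def residual_update(out, bias, residual, p, training, fused_fn=None, drop_path=None):
    # Fused path: bias + dropout + residual in one jitted function.
    if drop_path is None and bias.numel() != 0 and fused_fn is not None:
        return fused_fn(out, bias, residual, p)
    # Fallback: unfused ops, skipping the bias when it is the empty sentinel.
    if bias.numel() != 0:
        out = out + bias
    out = torch.nn.functional.dropout(out, p=p, training=training)
    if drop_path is not None:
        out = drop_path(out)
    return residual + out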