Unverified commit d90bf212 authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix usage of return_bias argument (#114)



* fix usage of return_bias argument
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 06486a00
......@@ -1123,6 +1123,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
......@@ -1187,7 +1188,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
stride=1,
)
if self.use_bias or self.return_bias:
if self.use_bias:
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
......@@ -1229,7 +1230,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
stride=1,
)
if self.use_bias or self.return_bias:
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
......@@ -1246,9 +1247,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.use_bias:
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
self.use_bias = False
else:
self.gemm_bias_unfused_add = False
......@@ -1331,7 +1331,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.weight1_fp8 if self.fp8 else None,
self.weight1_t_fp8 if self.fp8 else None,
bias_tensor,
self.use_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
self.fp8,
......@@ -1776,6 +1776,7 @@ class Linear(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.parameters_split = parameters_split
if tp_group is None:
......@@ -1819,7 +1820,7 @@ class Linear(TransformerEngineBaseModule):
stride=1,
)
if self.use_bias or self.return_bias:
if self.use_bias:
self.register_buffer("bias_tensor",
torch.empty(
self.out_features,
......@@ -1861,7 +1862,7 @@ class Linear(TransformerEngineBaseModule):
stride=1,
)
if self.use_bias or self.return_bias:
if self.use_bias:
self.register_parameter(
bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
)
......@@ -1878,9 +1879,8 @@ class Linear(TransformerEngineBaseModule):
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.use_bias:
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
self.use_bias = False
else:
self.gemm_bias_unfused_add = False
......@@ -1946,7 +1946,7 @@ class Linear(TransformerEngineBaseModule):
self.weight1_t_fp8 if self.fp8 else None,
inp,
bias_tensor,
self.use_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
......@@ -2667,6 +2667,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
self.set_parallel_mode = set_parallel_mode
......@@ -2759,7 +2760,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
stride=1,
)
if self.use_bias or self.return_bias:
if self.use_bias:
self.fc2_bias = Parameter(
torch.empty(
hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
......@@ -2770,9 +2771,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.set_parallel_mode and self.use_bias:
if self.set_parallel_mode and self.apply_bias:
self.gemm_bias_unfused_add = True
self.use_bias = False
else:
self.gemm_bias_unfused_add = False
......@@ -2845,7 +2845,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.weight2_fp8 if self.fp8 else None,
self.weight2_t_fp8 if self.fp8 else None,
self.fc2_bias,
self.use_bias,
self.apply_bias and not self.gemm_bias_unfused_add,
self.eps,
is_first_microbatch,
self.fp8,
......
......@@ -607,7 +607,7 @@ class MultiHeadAttention(torch.nn.Module):
hidden_size,
hidden_size,
init_method=output_layer_init_method,
bias=False,
bias=True,
return_bias=True,
parallel_mode="row" if set_parallel_mode else None,
**common_gemm_kwargs,
......@@ -1059,7 +1059,7 @@ class TransformerLayer(torch.nn.Module):
get_rng_state_tracker=get_rng_state_tracker,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
bias=False,
bias=True,
return_bias=True,
sequence_parallel=self.sequence_parallel,
params_dtype=params_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment