Commit d90bf212 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix usage of return_bias argument (#114)

* fix usage of return_bias argument

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 06486a00
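For context: in Transformer Engine modules, `bias` controls whether a learnable bias parameter exists, while `return_bias=True` tells the module *not* to add that bias to its output and to hand it back to the caller instead, so the addition can be fused with a later operation (for example dropout plus residual add). A minimal usage sketch, assuming a CUDA device and a recent `transformer_engine.pytorch`; the sizes are illustrative:

```python
import torch
import transformer_engine.pytorch as te

# The module owns a learnable bias, but with return_bias=True it does not
# add it to the GEMM output; the bias tensor is handed back to the caller.
proj = te.Linear(1024, 1024, bias=True, return_bias=True)

x = torch.randn(8, 1024, device="cuda")
out, bias = proj(x)  # `bias` has NOT been added to `out` yet

# The caller applies it later, e.g. fused with dropout and the residual add.
y = torch.nn.functional.dropout(out + bias, p=0.1, training=True) + x
```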
@@ -1123,6 +1123,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         self.use_bias = bias
         self.return_bias = return_bias
+        self.apply_bias = bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
         self.parameters_split = parameters_split
         self.zero_centered_gamma = zero_centered_gamma
@@ -1187,7 +1188,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 stride=1,
             )
-            if self.use_bias or self.return_bias:
+            if self.use_bias:
                 self.register_buffer("bias_tensor",
                                      torch.empty(
                                          self.out_features,
@@ -1229,7 +1230,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                     stride=1,
                 )
-                if self.use_bias or self.return_bias:
+                if self.use_bias:
                     self.register_parameter(
                         bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                     )
@@ -1246,9 +1247,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
-        if self.parallel_mode == "row" and self.use_bias:
+        if self.parallel_mode == "row" and self.apply_bias:
             self.gemm_bias_unfused_add = True
-            self.use_bias = False
         else:
             self.gemm_bias_unfused_add = False
@@ -1331,7 +1331,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.weight1_fp8 if self.fp8 else None,
             self.weight1_t_fp8 if self.fp8 else None,
             bias_tensor,
-            self.use_bias,
+            self.apply_bias and not self.gemm_bias_unfused_add,
             self.eps,
             is_first_microbatch,
             self.fp8,
@@ -1776,6 +1776,7 @@ class Linear(TransformerEngineBaseModule):
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         self.use_bias = bias
         self.return_bias = return_bias
+        self.apply_bias = bias and not return_bias
         self.parameters_split = parameters_split
         if tp_group is None:
@@ -1819,7 +1820,7 @@ class Linear(TransformerEngineBaseModule):
                 stride=1,
             )
-            if self.use_bias or self.return_bias:
+            if self.use_bias:
                 self.register_buffer("bias_tensor",
                                      torch.empty(
                                          self.out_features,
@@ -1861,7 +1862,7 @@ class Linear(TransformerEngineBaseModule):
                     stride=1,
                 )
-                if self.use_bias or self.return_bias:
+                if self.use_bias:
                     self.register_parameter(
                         bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                     )
@@ -1878,9 +1879,8 @@ class Linear(TransformerEngineBaseModule):
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
-        if self.parallel_mode == "row" and self.use_bias:
+        if self.parallel_mode == "row" and self.apply_bias:
             self.gemm_bias_unfused_add = True
-            self.use_bias = False
         else:
             self.gemm_bias_unfused_add = False
@@ -1946,7 +1946,7 @@ class Linear(TransformerEngineBaseModule):
             self.weight1_t_fp8 if self.fp8 else None,
             inp,
             bias_tensor,
-            self.use_bias,
+            self.apply_bias and not self.gemm_bias_unfused_add,
             is_first_microbatch,
             self.fp8,
             self.fp8_calibration,
@@ -2667,6 +2667,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         self.use_bias = bias
         self.return_bias = return_bias
+        self.apply_bias = bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
         self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
         self.set_parallel_mode = set_parallel_mode
@@ -2759,7 +2760,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 stride=1,
             )
-        if self.use_bias or self.return_bias:
+        if self.use_bias:
             self.fc2_bias = Parameter(
                 torch.empty(
                     hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
@@ -2770,9 +2771,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
-        if self.set_parallel_mode and self.use_bias:
+        if self.set_parallel_mode and self.apply_bias:
             self.gemm_bias_unfused_add = True
-            self.use_bias = False
         else:
             self.gemm_bias_unfused_add = False
@@ -2845,7 +2845,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.weight2_fp8 if self.fp8 else None,
             self.weight2_t_fp8 if self.fp8 else None,
             self.fc2_bias,
-            self.use_bias,
+            self.apply_bias and not self.gemm_bias_unfused_add,
             self.eps,
             is_first_microbatch,
             self.fp8,
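Taken together, the changes above separate three concerns that the old code conflated: whether a bias parameter exists (`use_bias`), whether the module itself adds it (`apply_bias`), and whether that add can be fused into the GEMM. A condensed restatement of the new flag logic, as plain Python for illustration (not verbatim repository code):

```python
def bias_flags(bias: bool, return_bias: bool, row_parallel: bool):
    """Restates the flag logic introduced by this patch (illustrative only)."""
    use_bias = bias                        # a bias parameter is created
    apply_bias = bias and not return_bias  # the module adds the bias itself
    # For row-parallel layers the bias must be added after the TP collective,
    # so the add cannot be fused into the GEMM.
    gemm_bias_unfused_add = row_parallel and apply_bias
    # The value the forward pass now passes to the GEMM as its bias flag:
    fuse_bias_into_gemm = apply_bias and not gemm_bias_unfused_add
    return use_bias, apply_bias, gemm_bias_unfused_add, fuse_bias_into_gemm

# return_bias=True no longer drags a bias into the GEMM path, and the old
# in-place `self.use_bias = False` mutation is gone:
assert bias_flags(bias=True, return_bias=True, row_parallel=True) == (True, False, False, False)
assert bias_flags(bias=True, return_bias=False, row_parallel=True) == (True, True, True, False)
```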
@@ -607,7 +607,7 @@ class MultiHeadAttention(torch.nn.Module):
             hidden_size,
             hidden_size,
             init_method=output_layer_init_method,
-            bias=False,
+            bias=True,
             return_bias=True,
             parallel_mode="row" if set_parallel_mode else None,
             **common_gemm_kwargs,
@@ -1059,7 +1059,7 @@ class TransformerLayer(torch.nn.Module):
             get_rng_state_tracker=get_rng_state_tracker,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
-            bias=False,
+            bias=True,
             return_bias=True,
             sequence_parallel=self.sequence_parallel,
             params_dtype=params_dtype,
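With this patch, the attention output projection and the TransformerLayer MLP pass `bias=True, return_bias=True` explicitly, since `return_bias=True` alone no longer creates the bias parameter. The returned bias is then consumed downstream; below is a plain-PyTorch stand-in for that consumption (TE's actual fused bias-dropout-add implementation may differ, and the call site is hypothetical):

```python
import torch

def bias_dropout_add(x, bias, residual, prob, training):
    # Plain-PyTorch stand-in for the fused bias-dropout-add that consumes
    # the bias returned by a return_bias=True projection.
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out

# Hypothetical call site:
#   attn_out, attn_bias = self.proj(context)          # return_bias=True
#   hidden = bias_dropout_add(attn_out, attn_bias, residual, 0.1, True)
```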