Unverified Commit 178f1365 authored by Marks101, committed by GitHub

[PyTorch] Fix bias initialization introduced in #596 (#622)


Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
parent f196d14b
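
The diff adds an explicit init_fn when registering each bias parameter, so the deferred-initialization machinery introduced in #596 fills these biases with zeros instead of leaving them with whatever their default initialization produces. init_method_constant comes from the package's ..utils module (imported in the last hunk below); the following is a minimal sketch of the idea, not the exact upstream implementation:

    import torch

    def init_method_constant(val):
        """Return an init function that fills its tensor argument with a constant.

        Sketch of the helper the diff imports from ..utils; the real
        Transformer Engine implementation may differ in details.
        """
        def init_(tensor):
            return torch.nn.init.constant_(tensor, val)
        return init_

    # init_method_constant(0.0)(bias) zero-fills the registered bias in place.
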
@@ -781,7 +781,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
             layer_norm_bias = torch.nn.Parameter(
                 torch.empty(in_features, device=device, dtype=params_dtype)
             )
-            self.register_parameter('layer_norm_bias', layer_norm_bias)
+            self.register_parameter('layer_norm_bias', layer_norm_bias,
+                                    init_fn=init_method_constant(0.0))
             setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
         else:
             self.layer_norm_bias = None
@@ -873,7 +874,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 if is_subview:
                     bias = bias[split_start:split_end]
                 bias = torch.nn.Parameter(bias)
-                self.register_parameter(self.bias_names[i], bias)
+                self.register_parameter(self.bias_names[i], bias,
+                                        init_fn=init_method_constant(0.0))
                 if parallel_mode == "row":
                     bias.sequence_parallel = sequence_parallel
             else:
@@ -1213,7 +1213,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             layer_norm_bias = Parameter(
                 torch.empty(hidden_size, device=device, dtype=params_dtype)
             )
-            self.register_parameter('layer_norm_bias', layer_norm_bias)
+            self.register_parameter('layer_norm_bias', layer_norm_bias,
+                                    init_fn=init_method_constant(0.0))
             setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
         else:
             self.layer_norm_bias = None
@@ -1240,7 +1241,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc1_bias = Parameter(
                 torch.empty(fc1_output_features, device=device, dtype=params_dtype)
             )
-            self.register_parameter('fc1_bias', fc1_bias)
+            self.register_parameter('fc1_bias', fc1_bias,
+                                    init_fn=init_method_constant(0.0))
             set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition
         else:
             self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
@@ -1260,7 +1262,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             fc2_bias = Parameter(
                 torch.empty(hidden_size, device=device, dtype=params_dtype)
             )
-            self.register_parameter('fc2_bias', fc2_bias)
+            self.register_parameter('fc2_bias', fc2_bias,
+                                    init_fn=init_method_constant(0.0))
         # RPL
         if self.set_parallel_mode:
             setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition
@@ -26,6 +26,7 @@ from ..utils import (
     cast_if_needed,
     assert_dim_for_fp8_exec,
     clear_tensor_data,
+    init_method_constant,
 )
 from ..distributed import (
     set_tensor_model_parallel_attributes,
@@ -764,7 +765,8 @@ class Linear(TransformerEngineBaseModule):
                 if is_subview:
                     bias = bias[split_start:split_end]
                 bias = torch.nn.Parameter(bias)
-                self.register_parameter(self.bias_names[i], bias)
+                self.register_parameter(self.bias_names[i], bias,
+                                        init_fn=init_method_constant(0.0))
                 if parallel_mode == "row":
                     bias.sequence_parallel = sequence_parallel
             else:
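
For context, #596 moved these modules to a deferred parameter-initialization scheme in which register_parameter records an init_fn that is applied later (for example when parameters are materialized and reset). Without the explicit init_fn=init_method_constant(0.0) restored by this commit, the biases would no longer be reset to zero. The sketch below illustrates that pattern with simplified, illustrative names; it is not the upstream TransformerEngineBaseModule code.

    import torch

    class _DeferredInitModuleSketch(torch.nn.Module):
        """Illustrative stand-in for a base module with deferred parameter init."""

        def __init__(self):
            super().__init__()
            self._param_init_fns = {}

        def register_parameter(self, name, param, init_fn=None):
            # Register the parameter as usual and remember how to (re)initialize it.
            super().register_parameter(name, param)
            if init_fn is not None:
                self._param_init_fns[name] = init_fn

        def reset_parameters(self):
            # Apply each recorded init_fn; biases registered with
            # init_method_constant(0.0) come back as zeros.
            for name, init_fn in self._param_init_fns.items():
                with torch.no_grad():
                    init_fn(getattr(self, name))

    # Usage: a bias registered with a zero-constant init_fn starts (and resets) at zero.
    m = _DeferredInitModuleSketch()
    m.register_parameter("bias", torch.nn.Parameter(torch.empty(4)),
                         init_fn=lambda t: torch.nn.init.zeros_(t))
    m.reset_parameters()
    assert torch.all(m.bias == 0)
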