"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "8f6c52485adb55c07d92e067d3de7ed6b4bc3615"
Unverified commit 178f1365 authored by Marks101, committed by GitHub
Browse files

[PyTorch] Fix bias initialization introduced in #596 (#622)


Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
parent f196d14b
...@@ -781,7 +781,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -781,7 +781,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
layer_norm_bias = torch.nn.Parameter( layer_norm_bias = torch.nn.Parameter(
torch.empty(in_features, device=device, dtype=params_dtype) torch.empty(in_features, device=device, dtype=params_dtype)
) )
self.register_parameter('layer_norm_bias', layer_norm_bias) self.register_parameter('layer_norm_bias', layer_norm_bias,
init_fn=init_method_constant(0.0))
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
else: else:
self.layer_norm_bias = None self.layer_norm_bias = None
...@@ -873,7 +874,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -873,7 +874,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
if is_subview: if is_subview:
bias = bias[split_start:split_end] bias = bias[split_start:split_end]
bias = torch.nn.Parameter(bias) bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias) self.register_parameter(self.bias_names[i], bias,
init_fn=init_method_constant(0.0))
if parallel_mode == "row": if parallel_mode == "row":
bias.sequence_parallel = sequence_parallel bias.sequence_parallel = sequence_parallel
else: else:
......
...@@ -1213,7 +1213,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1213,7 +1213,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
layer_norm_bias = Parameter( layer_norm_bias = Parameter(
torch.empty(hidden_size, device=device, dtype=params_dtype) torch.empty(hidden_size, device=device, dtype=params_dtype)
) )
self.register_parameter('layer_norm_bias', layer_norm_bias) self.register_parameter('layer_norm_bias', layer_norm_bias,
init_fn=init_method_constant(0.0))
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition
else: else:
self.layer_norm_bias = None self.layer_norm_bias = None
...@@ -1240,7 +1241,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1240,7 +1241,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_bias = Parameter( fc1_bias = Parameter(
torch.empty(fc1_output_features, device=device, dtype=params_dtype) torch.empty(fc1_output_features, device=device, dtype=params_dtype)
) )
self.register_parameter('fc1_bias', fc1_bias) self.register_parameter('fc1_bias', fc1_bias,
init_fn=init_method_constant(0.0))
set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition
else: else:
self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
...@@ -1260,7 +1262,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1260,7 +1262,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_bias = Parameter( fc2_bias = Parameter(
torch.empty(hidden_size, device=device, dtype=params_dtype) torch.empty(hidden_size, device=device, dtype=params_dtype)
) )
self.register_parameter('fc2_bias', fc2_bias) self.register_parameter('fc2_bias', fc2_bias,
init_fn=init_method_constant(0.0))
# RPL # RPL
if self.set_parallel_mode: if self.set_parallel_mode:
setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition
......
...@@ -26,6 +26,7 @@ from ..utils import ( ...@@ -26,6 +26,7 @@ from ..utils import (
cast_if_needed, cast_if_needed,
assert_dim_for_fp8_exec, assert_dim_for_fp8_exec,
clear_tensor_data, clear_tensor_data,
init_method_constant,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -764,7 +765,8 @@ class Linear(TransformerEngineBaseModule): ...@@ -764,7 +765,8 @@ class Linear(TransformerEngineBaseModule):
if is_subview: if is_subview:
bias = bias[split_start:split_end] bias = bias[split_start:split_end]
bias = torch.nn.Parameter(bias) bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias) self.register_parameter(self.bias_names[i], bias,
init_fn=init_method_constant(0.0))
if parallel_mode == "row": if parallel_mode == "row":
bias.sequence_parallel = sequence_parallel bias.sequence_parallel = sequence_parallel
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment