Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
178f1365
Unverified
Commit
178f1365
authored
Jan 22, 2024
by
Marks101
Committed by
GitHub
Jan 22, 2024
Browse files
[PyTorch] Fix bias initialization introduced in #596 (#622)
Signed-off-by:
Markus Schnoes
<
markus.schnoes@gmx.de
>
parent
f196d14b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
6 deletions
+13
-6
transformer_engine/pytorch/module/layernorm_linear.py
transformer_engine/pytorch/module/layernorm_linear.py
+4
-2
transformer_engine/pytorch/module/layernorm_mlp.py
transformer_engine/pytorch/module/layernorm_mlp.py
+6
-3
transformer_engine/pytorch/module/linear.py
transformer_engine/pytorch/module/linear.py
+3
-1
No files found.
transformer_engine/pytorch/module/layernorm_linear.py
View file @
178f1365
...
...
@@ -781,7 +781,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
layer_norm_bias
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
in_features
,
device
=
device
,
dtype
=
params_dtype
)
)
self
.
register_parameter
(
'layer_norm_bias'
,
layer_norm_bias
)
self
.
register_parameter
(
'layer_norm_bias'
,
layer_norm_bias
,
init_fn
=
init_method_constant
(
0.0
))
setattr
(
self
.
layer_norm_bias
,
"sequence_parallel"
,
self
.
sequence_parallel
)
# pylint: disable=access-member-before-definition
else
:
self
.
layer_norm_bias
=
None
...
...
@@ -873,7 +874,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
if
is_subview
:
bias
=
bias
[
split_start
:
split_end
]
bias
=
torch
.
nn
.
Parameter
(
bias
)
self
.
register_parameter
(
self
.
bias_names
[
i
],
bias
)
self
.
register_parameter
(
self
.
bias_names
[
i
],
bias
,
init_fn
=
init_method_constant
(
0.0
))
if
parallel_mode
==
"row"
:
bias
.
sequence_parallel
=
sequence_parallel
else
:
...
...
transformer_engine/pytorch/module/layernorm_mlp.py
View file @
178f1365
...
...
@@ -1213,7 +1213,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
layer_norm_bias
=
Parameter
(
torch
.
empty
(
hidden_size
,
device
=
device
,
dtype
=
params_dtype
)
)
self
.
register_parameter
(
'layer_norm_bias'
,
layer_norm_bias
)
self
.
register_parameter
(
'layer_norm_bias'
,
layer_norm_bias
,
init_fn
=
init_method_constant
(
0.0
))
setattr
(
self
.
layer_norm_bias
,
"sequence_parallel"
,
self
.
sequence_parallel
)
# pylint: disable=access-member-before-definition
else
:
self
.
layer_norm_bias
=
None
...
...
@@ -1240,7 +1241,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc1_bias
=
Parameter
(
torch
.
empty
(
fc1_output_features
,
device
=
device
,
dtype
=
params_dtype
)
)
self
.
register_parameter
(
'fc1_bias'
,
fc1_bias
)
self
.
register_parameter
(
'fc1_bias'
,
fc1_bias
,
init_fn
=
init_method_constant
(
0.0
))
set_tensor_model_parallel_attributes
(
self
.
fc1_bias
,
True
,
0
,
1
)
# pylint: disable=access-member-before-definition
else
:
self
.
fc1_bias
=
torch
.
Tensor
().
to
(
dtype
=
params_dtype
,
device
=
device
)
...
...
@@ -1260,7 +1262,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_bias
=
Parameter
(
torch
.
empty
(
hidden_size
,
device
=
device
,
dtype
=
params_dtype
)
)
self
.
register_parameter
(
'fc2_bias'
,
fc2_bias
)
self
.
register_parameter
(
'fc2_bias'
,
fc2_bias
,
init_fn
=
init_method_constant
(
0.0
))
# RPL
if
self
.
set_parallel_mode
:
setattr
(
self
.
fc2_bias
,
"sequence_parallel"
,
sequence_parallel
)
# pylint: disable=access-member-before-definition
...
...
transformer_engine/pytorch/module/linear.py
View file @
178f1365
...
...
@@ -26,6 +26,7 @@ from ..utils import (
cast_if_needed
,
assert_dim_for_fp8_exec
,
clear_tensor_data
,
init_method_constant
,
)
from
..distributed
import
(
set_tensor_model_parallel_attributes
,
...
...
@@ -764,7 +765,8 @@ class Linear(TransformerEngineBaseModule):
if
is_subview
:
bias
=
bias
[
split_start
:
split_end
]
bias
=
torch
.
nn
.
Parameter
(
bias
)
self
.
register_parameter
(
self
.
bias_names
[
i
],
bias
)
self
.
register_parameter
(
self
.
bias_names
[
i
],
bias
,
init_fn
=
init_method_constant
(
0.0
))
if
parallel_mode
==
"row"
:
bias
.
sequence_parallel
=
sequence_parallel
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment