OpenDAS / FastMoE · Commits

Commit 8ddd246f, authored Feb 26, 2021 by Jiezhong Qiu

    use magatron's init method for ffn

Parent: 5e5b4044
Showing 1 changed file with 19 additions and 2 deletions.

fmoe/megatron.py (+19, -2)
```diff
@@ -30,6 +30,20 @@ class _FakeMegatronMLP(nn.Module):
         x = self.fc2(x)
         return x, torch.zeros_like(x)
 
 
+def _magatron_init_method(self, rng, sigma):
+    r'''
+    Init method based on N(0, sigma).
+    Copied from Megatron-LM
+    '''
+    device = self.weight.device
+    dtype = self.weight.dtype
+    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
+    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+
+    if self.bias is not None:
+        # Always initialize bias to zero.
+        with torch.no_grad():
+            self.bias.zero_()
 
 
 def _random_init_weight(self, rng):
     r'''
```
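Since `_magatron_init_method` is a free function that takes the target module as `self`, it can be applied to any layer exposing `weight` and `bias` attributes. A minimal usage sketch (the layer shape, seed, and sigma below are illustrative assumptions, not values from this commit):

```python
import numpy as np
import torch
from torch import nn

from fmoe.megatron import _magatron_init_method  # helper added in this commit

fc = nn.Linear(1024, 4096)         # illustrative expert-sized projection
rng = np.random.default_rng(1234)  # numpy rng, as reset_parameters uses

# Draw N(0, sigma) weights from the numpy rng, then zero the bias.
_magatron_init_method(fc, rng, sigma=0.02)

print(fc.weight.std().item())      # close to 0.02
print(fc.bias.abs().max().item())  # exactly 0.0
```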
```diff
@@ -71,6 +85,8 @@ class MegatronMLP(FMoETransformerMLP):
             expert_dp_comm='none' if args.distributed_experts else 'dp'
         )
         self.hidden_size = args.hidden_size
         self.rank = args.rank
+        self.sigma = args.init_method_std
+        self.num_layers = args.num_layers
         self.reset_parameters()
 
     def reset_parameters(self):
```
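The constructor now caches two extra Megatron-LM hyperparameters for use in `reset_parameters`. As a rough sketch, the fields this constructor reads off the Megatron `args` object look like the stand-in namespace below (hypothetical values; the real object comes from Megatron-LM's argument parser):

```python
from types import SimpleNamespace

# Hypothetical stand-in for Megatron-LM's parsed args; only the fields
# this constructor touches are listed.
args = SimpleNamespace(
    hidden_size=1024,          # -> self.hidden_size
    rank=0,                    # -> self.rank, offsets the init rng seed
    init_method_std=0.02,      # -> self.sigma (Megatron-LM's default)
    num_layers=24,             # -> self.num_layers, scales the output init
    distributed_experts=True,  # selects 'none' vs 'dp' for expert_dp_comm
)
```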
```diff
@@ -80,8 +96,9 @@ class MegatronMLP(FMoETransformerMLP):
         additional numpy rng is used.
         '''
         rng = np.random.default_rng(np.random.randint(2048) + self.rank)
-        _random_init_weight(self.experts.htoh4, rng)
-        _random_init_weight(self.experts.h4toh, rng)
+        _magatron_init_method(self.experts.htoh4, rng, self.sigma)
+        std = self.sigma / math.sqrt(2.0 * self.num_layers)
+        _magatron_init_method(self.experts.h4toh, rng, std)
 
     def forward(self, inp):
         return super().forward(inp), torch.zeros(
             self.hidden_size,
```
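The functional change is confined to `reset_parameters`: the input expert projection `htoh4` is now drawn from N(0, sigma), while the output projection `h4toh` uses Megatron-LM's scaled initialization, sigma / sqrt(2 * num_layers), which keeps the residual stream's variance bounded as 2 * num_layers residual branches are summed with depth. Note also that the numpy rng is seeded with a per-rank offset, so experts on different ranks start from different weights. A quick numeric sketch of the scaling rule (sigma uses Megatron-LM's default `init_method_std`; the layer counts are illustrative):

```python
import math

sigma = 0.02  # Megatron-LM's default init_method_std
for num_layers in (12, 24, 48):
    # std shrinks with depth, damping each layer's residual contribution.
    std = sigma / math.sqrt(2.0 * num_layers)
    print(num_layers, round(std, 5))
# 12 0.00408
# 24 0.00289
# 48 0.00204
```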