OpenDAS / FastMoE
Commit c844413b, authored Mar 16, 2021 by Jiezhong Qiu
Parent: 49a4678c

fix pylint
Showing 2 changed files with 11 additions and 12 deletions:
fmoe/megatron.py     +9 / -8
fmoe/transformer.py  +2 / -4
fmoe/megatron.py
@@ -5,9 +5,9 @@ See `examples/megatron` for usage instructions.
 """
 import os
 import math
-import numpy as np
 import random
 from collections import OrderedDict
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
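The only change in this hunk is import ordering: pylint's wrong-import-order check (C0411) expects standard-library imports to be grouped ahead of third-party ones, so `import numpy as np` moves below the last standard-library import. A minimal illustration of the grouping pylint expects (the module names are only examples, not tied to this file):

# Standard-library imports come first...
import os
import random
from collections import OrderedDict

# ...then third-party imports such as numpy and torch.
import numpy as np
import torch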
@@ -392,6 +392,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint with expert parallel """
+    # TODO: update patch
     from megatron import get_args
     from megatron import mpu
@@ -405,15 +406,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     print('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save), flush=True)
-    data_parallel_rank = mpu.get_data_parallel_rank()
     # Arguments, iteration, and model.
     state_dict = {}
     state_dict['args'] = args
     state_dict['checkpoint_version'] = 3.0
     state_dict['iteration'] = iteration
-    keep_vars = False if mpu.get_data_parallel_rank() == 0 else True
     state_dict['model'] = model.state_dict_for_save_checkpoint(
-        keep_vars=keep_vars)
+        keep_vars=(mpu.get_data_parallel_rank() > 0))
     if mpu.get_data_parallel_rank() != 0:
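The dropped lines are pylint cleanups: data_parallel_rank was an apparently unused local, and the False-if-else-True expression simplifies to a direct comparison passed straight to keep_vars. The behaviour that survives is that data-parallel rank 0 saves a detached snapshot while every other rank asks for keep_vars=True. For readers unfamiliar with that flag: torch.nn.Module.state_dict(keep_vars=True) returns the live Parameter objects instead of detached copies, so Python attributes attached to a parameter, such as the dp_comm tag this file checks a few lines later, stay visible to the expert-filtering step. A minimal sketch of the difference, with a plain nn.Linear and a hand-set dp_comm tag used purely for illustration:

from torch import nn

layer = nn.Linear(4, 4)
layer.weight.dp_comm = 'none'            # tag a parameter, FastMoE-style

detached = layer.state_dict()            # keep_vars=False (default): detached tensors
live = layer.state_dict(keep_vars=True)  # keep_vars=True: the Parameter objects themselves

print(hasattr(detached['weight'], 'dp_comm'))  # False: detach() drops the custom attribute
print(hasattr(live['weight'], 'dp_comm'))      # True:  the tagged Parameter survives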
@@ -421,15 +420,17 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict_new = state_dict.__class__()
         for k, v in state_dict.items():
             # megatron uses both dict and OrderedDict in its state_dict
-            if isinstance(v, OrderedDict) or isinstance(v, dict):
+            if isinstance(v, (OrderedDict, dict)):
                 v_new = extract_expert_param(v, expert_dp_comm)
-                if len(v_new):
+                if len(v_new) > 0:
                     state_dict_new[k] = v_new
             elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                 state_dict_new[k] = v.detach()
         return state_dict_new
 
-    state_dict['model'] = extract_expert_param(state_dict['model'], 'none')
+    state_dict['model'] = extract_expert_param(state_dict['model'], expert_dp_comm='none')
     # Optimizer stuff.
     if not args.no_save_optim:
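The two rewrites inside the loop are again pylint-driven: one isinstance call with a tuple of types instead of two calls joined by `or` (consider-merging-isinstance), and an explicit `> 0` comparison instead of a bare len() truth test (len-as-condition). Since the diff only shows the body of the nested helper, here is a self-contained restatement of the filtering logic as it appears in this hunk, plus a toy usage; treat it as a sketch for reading purposes, not as the module's verbatim definition:

from collections import OrderedDict
import torch

def extract_expert_param(state_dict, expert_dp_comm):
    # Keep only entries whose tensors carry a matching dp_comm tag,
    # recursing into nested dict / OrderedDict values.
    state_dict_new = state_dict.__class__()
    for k, v in state_dict.items():
        if isinstance(v, (OrderedDict, dict)):
            v_new = extract_expert_param(v, expert_dp_comm)
            if len(v_new) > 0:
                state_dict_new[k] = v_new
        elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
            state_dict_new[k] = v.detach()
    return state_dict_new

w = torch.zeros(2)
w.dp_comm = 'none'                      # expert parameter, tagged
b = torch.ones(2)                       # untagged parameter
model_state = {'expert': {'weight': w}, 'other': b}
print(extract_expert_param(model_state, 'none'))   # {'expert': {'weight': tensor([0., 0.])}}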
fmoe/transformer.py
@@ -15,10 +15,8 @@ class _Expert(nn.Module):
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True,
-                rank=rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
-                rank=rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
         self.activation = activation
 
     def forward(self, inp, fwd_expert_count):
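This hunk appears to only join the two wrapped FMoELinear constructor calls onto single lines (the per-file stats show +2 / -4), so behaviour is unchanged. The diff stops at the forward signature, but the attributes set up in __init__ already describe a per-expert feed-forward block: project d_model up to d_hidden, apply the activation, project back to d_model, with fwd_expert_count presumably carrying the per-expert token counts for the expert-parallel linear layers. Below is a self-contained stand-in that shows the same dataflow with plain nn.Linear layers; FMoELinear and the fwd_expert_count argument are deliberately left out, so this is only a sketch of the shape of the computation, not the module's actual forward:

import torch
from torch import nn
import torch.nn.functional as F

class ExpertFFNSketch(nn.Module):
    # Stand-in for _Expert using nn.Linear instead of FMoELinear,
    # just to show the d_model -> d_hidden -> d_model dataflow.
    def __init__(self, d_model, d_hidden, activation):
        super().__init__()
        self.htoh4 = nn.Linear(d_model, d_hidden)
        self.h4toh = nn.Linear(d_hidden, d_model)
        self.activation = activation

    def forward(self, inp):
        return self.h4toh(self.activation(self.htoh4(inp)))

out = ExpertFFNSketch(d_model=16, d_hidden=64, activation=F.gelu)(torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 16])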