Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b9b469ea
Unverified
Commit
b9b469ea
authored
Apr 12, 2022
by
HELSON
Committed by
GitHub
Apr 12, 2022
Browse files
[moe] add checkpoint for moe zero test (#729)
parent
6f7d1362
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
9 deletions
+10
-9
tests/test_moe/test_moe_zero_init.py
tests/test_moe/test_moe_zero_init.py
+5
-4
tests/test_moe/test_moe_zero_model.py
tests/test_moe/test_moe_zero_model.py
+2
-2
tests/test_moe/test_moe_zero_optim.py
tests/test_moe/test_moe_zero_optim.py
+3
-3
No files found.
tests/test_moe/test_moe_zero_init.py
View file @
b9b469ea
...
...
@@ -5,6 +5,7 @@ import pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
from
colossalai.nn
import
CheckpointModule
from
colossalai.logging
import
get_dist_logger
from
colossalai.testing
import
parameterize
from
colossalai.utils
import
free_port
...
...
@@ -18,10 +19,10 @@ from colossalai.utils import get_current_device
from
tests.test_zero_data_parallel.common
import
CONFIG
class
MoeModel
(
nn
.
Module
):
class
MoeModel
(
Checkpoint
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
__init__
(
self
,
checkpoint
:
bool
=
False
):
super
().
__init__
(
checkpoint
)
self
.
proj1
=
nn
.
Linear
(
4
,
16
)
expert_cls
=
nn
.
Linear
expert_args_dict
=
dict
(
in_features
=
16
,
out_features
=
16
)
...
...
@@ -52,7 +53,7 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
shard_strategy
=
shard_strategy_class
(),
shard_param
=
True
,
model_numel_tensor
=
model_numel_tensor
):
model
=
MoeModel
()
model
=
MoeModel
(
checkpoint
=
True
)
for
name
,
param
in
model
.
named_parameters
():
assert
hasattr
(
param
,
'colo_attr'
)
...
...
tests/test_moe/test_moe_zero_model.py
View file @
b9b469ea
...
...
@@ -31,7 +31,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cuda'
,
torch
.
cuda
.
current_device
()),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
zero_model
=
MoeModel
()
zero_model
=
MoeModel
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
use_memory_tracer
=
True
)
# check whether parameters are identical in ddp
...
...
@@ -39,7 +39,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
if
not
p
.
colo_attr
.
param_is_sharded
and
p
.
colo_attr
.
is_replicated
:
assert_equal_in_group
(
p
.
colo_attr
.
sharded_data_tensor
.
payload
)
model
=
MoeModel
().
half
()
model
=
MoeModel
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
model
=
model
.
cuda
()
grad_handler
=
MoeGradientHandler
(
model
)
...
...
tests/test_moe/test_moe_zero_optim.py
View file @
b9b469ea
...
...
@@ -65,7 +65,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
with
ZeroInitContext
(
target_device
=
torch
.
device
(
'cpu'
)
if
cpu_offload
else
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
zero_model
=
MoeModel
()
zero_model
=
MoeModel
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
shard_strategy
,
...
...
@@ -78,7 +78,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
if
not
p
.
colo_attr
.
param_is_sharded
and
p
.
colo_attr
.
is_replicated
:
assert_equal_in_group
(
p
.
colo_attr
.
sharded_data_tensor
.
payload
.
to
(
get_current_device
()))
model
=
MoeModel
().
half
()
model
=
MoeModel
(
checkpoint
=
True
).
half
()
col_model_deepcopy
(
zero_model
,
model
)
model
=
model
.
cuda
().
float
()
...
...
@@ -129,4 +129,4 @@ def test_moe_zero_optim(world_size):
if
__name__
==
'__main__'
:
test_moe_zero_optim
(
world_size
=
2
)
test_moe_zero_optim
(
world_size
=
4
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment