OpenDAS / ColossalAI · Commits

Commit 527758b2 (unverified)
[hotfix] fix a running error in test_colo_checkpoint.py (#1387)

Authored by HELSON on Jul 29, 2022; committed via GitHub on Jul 29, 2022.
Parent: f792507f
Changes: 4 changed files with 13 additions and 4 deletions (+13 −4).
- colossalai/gemini/chunk.py (+1 −1)
- colossalai/utils/checkpoint/module_checkpoint.py (+6 −0)
- colossalai/utils/checkpoint/utils.py (+1 −1)
- tests/test_utils/test_colo_checkpoint.py (+5 −2)
colossalai/gemini/chunk.py

```diff
@@ -208,7 +208,7 @@ class Chunk:
             tensor (torch.Tensor): a torch Tensor object.
             tensor_state (TensorState): the target state for transition.
         """
-        assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
+        assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
         # As the gradient hook can be triggered either before or after post-backward
         # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
         # or compute -> ready_for_reduce -> hold_after_bwd
```
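The removed line compared the tensor object itself with an enum member, so the intended invariant (a single tensor may never be freed on its own, only a whole chunk) was never actually checked. Below is a minimal sketch of the corrected check; the pared-down `TensorState` members and the free function standing in for `Chunk.tensor_trans_state` are illustrative assumptions, not ColossalAI code:

```python
# Pared-down sketch, not ColossalAI code: the TensorState members and the
# function shape are assumptions standing in for Chunk.tensor_trans_state.
from enum import Enum

import torch


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD_AFTER_BWD = 2
    READY_FOR_REDUCE = 3


def tensor_trans_state(tensor: torch.Tensor, tensor_state: TensorState) -> None:
    # The invariant concerns the target *state*, not the tensor object:
    # a lone tensor may not be transitioned to FREE, only a whole chunk.
    # The old `assert tensor != TensorState.FREE` compared a torch.Tensor
    # against an enum member, so it never tested this condition.
    assert tensor_state != TensorState.FREE, 'Can only set a chunk of tensors to FREE'


tensor_trans_state(torch.zeros(4), TensorState.COMPUTE)    # invariant holds
```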
colossalai/utils/checkpoint/module_checkpoint.py

```diff
@@ -89,6 +89,12 @@ def load_checkpoint(path: str,
         torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
         load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
     """
+    # initialize the default parameters
+    if not torch_load_kwargs:
+        torch_load_kwargs = dict()
+    if not load_state_dict_kwargs:
+        load_state_dict_kwargs = dict()
+
     rank = dist.get_rank()
     mapping = dict()
     for n, p in model.named_parameters():
```
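These guards matter because Python raises a `TypeError` as soon as `None` is unpacked with `**`, which is exactly the kind of failure the commit title describes when a caller leaves both kwargs at their defaults. A self-contained sketch of the pattern follows; the function shape and the `{'model': ...}` checkpoint layout are assumptions, and the real `load_checkpoint` additionally remaps ColoTensor parameters across ranks:

```python
# Sketch only: a hypothetical load_checkpoint_sketch showing why the guards
# are needed, not ColossalAI's actual load_checkpoint.
from typing import Optional

import torch
import torch.nn as nn


def load_checkpoint_sketch(path: str,
                           model: nn.Module,
                           torch_load_kwargs: Optional[dict] = None,
                           load_state_dict_kwargs: Optional[dict] = None) -> None:
    # Without these guards, `f(**None)` raises
    # "TypeError: ... argument after ** must be a mapping, not NoneType".
    if not torch_load_kwargs:
        torch_load_kwargs = dict()
    if not load_state_dict_kwargs:
        load_state_dict_kwargs = dict()

    state = torch.load(path, **torch_load_kwargs)
    model.load_state_dict(state['model'], **load_state_dict_kwargs)


model = nn.Linear(2, 2)
torch.save({'model': model.state_dict()}, 'ckpt.pt')
load_checkpoint_sketch('ckpt.pt', model)    # no kwargs passed: guards kick in
```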
colossalai/utils/checkpoint/utils.py

```diff
@@ -24,7 +24,7 @@ def gather_tensor(colo_tensor: ColoTensor) -> None:
     if not colo_tensor.is_replicate():
         pg = colo_tensor.get_process_group()
         # for the group which contains rank 0
-        if pg.tp_rank_list()[0] == 0:
+        if pg.dp_local_rank() == 0:
             old_dist_spec = colo_tensor.dist_spec
             colo_tensor.to_replicate_()
             if dist.get_rank() != 0:
```
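The visible pattern here is gather-then-restore: remember the tensor's old distribution spec, replicate it so the saving rank holds the full data, then put the helper ranks back afterwards. A pure-Python sketch of that pattern follows; `FakeShardedTensor` and its methods are stand-ins, not ColossalAI's `ColoTensor` API, and the real `to_replicate_()` is a collective across the process group:

```python
# Stand-in types, not ColossalAI APIs: FakeShardedTensor mimics only the
# spec bookkeeping of ColoTensor, with no actual communication.
from dataclasses import dataclass


@dataclass
class FakeShardedTensor:
    data: list       # this rank's shard
    dist_spec: str   # 'shard' or 'replicate'

    def to_replicate_(self) -> None:
        # In the real code this is a collective gather across the process
        # group; here we only flip the bookkeeping for illustration.
        self.dist_spec = 'replicate'

    def set_dist_spec(self, spec: str) -> None:
        self.dist_spec = spec


def gather_for_save(t: FakeShardedTensor, is_saver_rank: bool) -> None:
    # Same shape as gather_tensor: remember the old spec, replicate so the
    # saving rank sees the full tensor, then restore the spec on ranks that
    # only helped with the gather.
    if t.dist_spec != 'replicate':
        old_dist_spec = t.dist_spec
        t.to_replicate_()
        if not is_saver_rank:
            t.set_dist_spec(old_dist_spec)


t = FakeShardedTensor(data=[1, 2], dist_spec='shard')
gather_for_save(t, is_saver_rank=False)
assert t.dist_spec == 'shard'   # a helper rank is restored to its old spec
```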
tests/test_utils/test_colo_checkpoint.py

```diff
@@ -146,6 +146,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
     data = data.to(get_current_device())
     label = label.to(get_current_device())
+    dist.broadcast(data, pg.tp_rank_list()[0], pg.tp_process_group())
+    dist.broadcast(label, pg.tp_rank_list()[0], pg.tp_process_group())
+    # Bcast rank0 data to all processes
     if criterion:
         output = model(data)
```
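`dist.broadcast(tensor, src, group)` copies the source rank's tensor to every rank in the group in place, so after these two calls all tensor-parallel ranks train on identical inputs even though the loader hands each process its own batch. A self-contained sketch of the call follows, using a one-process `gloo` group so it runs on any machine (the real test uses NCCL across several ranks; the port is an arbitrary assumption):

```python
# Minimal, single-process illustration of the broadcast pattern added above.
import os

import torch
import torch.distributed as dist


def main() -> None:
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29512')  # assumed free port
    dist.init_process_group('gloo', rank=0, world_size=1)

    data = torch.randn(4)
    # In-place copy from the source rank to every rank in the group; after
    # this call all participating ranks hold identical inputs.
    dist.broadcast(data, src=0)

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```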
```diff
@@ -183,9 +186,9 @@ def _run_checkpoint(model_name, init_spec_func, use_ddp, use_mp_reload, test_sch
 def run_dist(rank, world_size, port, use_ddp, use_mp_reload, test_scheduler):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
     # TODO(haichen) add BERT in the test
     # the data loader of BERT is in DDP mode, causing the input data is not replicated in the TP context
-    for model_name in ['simple_net']:
+    for model_name in ['bert']:
         _run_checkpoint(model_name, init_1d_row_for_linear_weight_spec, use_ddp, ...
```
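With the inputs now broadcast from the first tensor-parallel rank, the test can exercise `bert` even though its data loader runs in DDP mode, which is what the TODO above had flagged. Entry points like `run_dist` are typically driven with the spawn pattern; here is a runnable sketch with a stubbed worker body (the print, port, and world size are placeholder assumptions, not the actual test file):

```python
# Sketch of the spawn harness around a run_dist-style worker; the body is a
# stub so this runs without GPUs or colossalai installed.
from functools import partial

import torch.multiprocessing as mp


def run_dist(rank: int, world_size: int, port: int) -> None:
    # The real worker calls colossalai.launch(config={}, rank=rank,
    # world_size=world_size, host='localhost', port=port, backend='nccl')
    # and then loops over model names, as in the hunk above.
    print(f'rank {rank}/{world_size} would launch on port {port}')


if __name__ == '__main__':
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=29500)
    # mp.spawn passes the process index as the first positional argument.
    mp.spawn(run_func, nprocs=world_size)
```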