OpenDAS / ColossalAI · Commits

Unverified commit cb3a25a0, authored Oct 07, 2023 by Hongxin Liu, committed by GitHub on Oct 07, 2023.

[checkpointio] hotfix torch 2.0 compatibility (#4824)
parent ad23460c

Showing 3 changed files with 26 additions and 12 deletions (+26 −12)
colossalai/checkpoint_io/utils.py  (+5 −1)
colossalai/zero/gemini/gemini_optimizer.py  (+5 −1)
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py  (+16 −10)
colossalai/checkpoint_io/utils.py (view file @ cb3a25a0)
@@ -9,6 +9,7 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

 import torch
 import torch.nn as nn
+from packaging.version import Version
 from torch.optim import Optimizer

 from colossalai.tensor.d_tensor import (

@@ -663,7 +664,10 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
     """
     # Do the cleaning up as in src code of Pytorch.
-    optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+    if Version(torch.__version__) >= Version("2.0.0"):
+        optimizer._patch_step_function()  # To support multiprocessing pickle/unpickle
+    else:
+        optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault("differentiable", False)
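The fix in both files is the same pattern: PyTorch 2.0 removed the private `Optimizer._hook_for_profile` method in favor of `_patch_step_function`, so the checkpoint-loading epilogue now dispatches on the installed torch version. Below is a minimal standalone sketch of that pattern, assuming only `torch` and `packaging` are installed; the `loading_epilogue` name and the toy SGD optimizer are illustrative, not part of the commit.

```python
# A minimal sketch of the version-gated epilogue used in this commit, lifted
# out of ColossalAI so it runs standalone. The toy SGD optimizer below is
# illustrative only; the private torch methods are the ones the diff calls.
import torch
from packaging.version import Version
from torch.optim import Optimizer


def loading_epilogue(optimizer: Optimizer) -> None:
    # torch >= 2.0 replaced the private _hook_for_profile with
    # _patch_step_function; per the commit's own comment, the call supports
    # multiprocessing pickle/unpickle after load_state_dict.
    if Version(torch.__version__) >= Version("2.0.0"):
        optimizer._patch_step_function()
    else:
        optimizer._hook_for_profile()
    # Ensure the "differentiable" default exists, as torch 2.x expects.
    optimizer.defaults.setdefault("differentiable", False)


if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(4))
    opt = torch.optim.SGD([param], lr=0.1)
    opt.load_state_dict(opt.state_dict())  # round-trip the state dict
    loading_epilogue(opt)
    print("epilogue ran on torch", torch.__version__)
```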
colossalai/zero/gemini/gemini_optimizer.py (view file @ cb3a25a0)
@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union

 import torch
 import torch.distributed as dist
+from packaging.version import Version
 from torch.nn import Parameter
 from torch.optim import Optimizer

@@ -676,7 +677,10 @@ class GeminiOptimizer(OptimizerWrapper):

     def optimizer_loading_epilogue(self):
         # Epilogue when loading state_dict to pytorch optimizer.
-        self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+        if Version(torch.__version__) >= Version("2.0.0"):
+            self.optim._patch_step_function()  # To support multiprocessing pickle/unpickle
+        else:
+            self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
         self.optim.defaults.setdefault("differentiable", False)

     def load_state_dict(self, state_dict: dict):
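Both epilogues gate on `Version(torch.__version__)` rather than comparing version strings directly. A quick illustration of why `packaging.version.Version` is used; the version strings here are made-up examples, not tied to this commit:

```python
# How packaging.version.Version compares the strings used in the gate above.
from packaging.version import Version

assert Version("2.0.1+cu118") >= Version("2.0.0")  # local build tags parse cleanly
assert Version("1.13.1") < Version("2.0.0")        # plain releases order as expected
assert Version("10.0.0") > Version("2.0.0")        # plain string comparison would get this wrong
assert Version("2.0.0a0") < Version("2.0.0")       # pre-releases of 2.0.0 sort below it
```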
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py (view file @ cb3a25a0)
 import pytest
 import torch
 import torch.distributed as dist
+from packaging.version import Version
 from torch.optim import Adam
 from utils import shared_tempdir

@@ -19,14 +20,8 @@ from colossalai.testing import (
 )

 from tests.kit.model_zoo import model_zoo

-
-@clear_cache_before_run()
-@parameterize("shard", [True, False])
-@parameterize("model_name", ["transformers_gpt"])
-@parameterize("size_per_shard", [32])
-@parameterize("test_config", [
-    {
-        "tp_size": 4,
-        "pp_size": 1,
+if Version(torch.__version__) < Version("2.0.0"):
+    TEST_CONFIGS = [
+        {
+            "tp_size": 4,
+            "pp_size": 1,

@@ -35,8 +30,19 @@ from tests.kit.model_zoo import model_zoo
         {"tp_size": 2, "pp_size": 2, "num_microbatches": 4, "precision": "fp16", "initial_scale": 1},
         {"tp_size": 2, "pp_size": 1, "zero_stage": 2, "precision": "fp16", "initial_scale": 1},
         {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
-    ],
-)
+    ]
+else:
+    TEST_CONFIGS = [
+        # TODO(ver217): other configs lead to hang
+        {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
+    ]
+
+
+@clear_cache_before_run()
+@parameterize("shard", [True, False])
+@parameterize("model_name", ["transformers_gpt"])
+@parameterize("size_per_shard", [32])
+@parameterize("test_config", TEST_CONFIGS)
 def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
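The test now chooses its config matrix at import time via a module-level `TEST_CONFIGS`, since decorator arguments are evaluated when the module loads. A sketch of the same gating idea using stock `pytest.mark.parametrize` instead of `colossalai.testing.parameterize`; the test name, body, and trimmed-down config keys are placeholders, not the ColossalAI test:

```python
# Import-time gating of a pytest parameter matrix on the torch version.
import pytest
import torch
from packaging.version import Version

if Version(torch.__version__) < Version("2.0.0"):
    CONFIGS = [{"tp_size": 4, "pp_size": 1}]
else:
    # Mirror the commit's workaround: keep only the config known not to hang.
    CONFIGS = [{"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1}]


@pytest.mark.parametrize("test_config", CONFIGS)
def test_config_shape(test_config: dict):
    # Placeholder assertion standing in for the real checkpoint round-trip.
    assert test_config["pp_size"] >= 1
```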