Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cb3a25a0
Unverified
Commit
cb3a25a0
authored
Oct 07, 2023
by
Hongxin Liu
Committed by
GitHub
Oct 07, 2023
Browse files
[checkpointio] hotfix torch 2.0 compatibility (#4824)
parent
ad23460c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
12 deletions
+26
-12
colossalai/checkpoint_io/utils.py
colossalai/checkpoint_io/utils.py
+5
-1
colossalai/zero/gemini/gemini_optimizer.py
colossalai/zero/gemini/gemini_optimizer.py
+5
-1
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
...heckpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+16
-10
No files found.
colossalai/checkpoint_io/utils.py
View file @
cb3a25a0
...
@@ -9,6 +9,7 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
...
@@ -9,6 +9,7 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
packaging.version
import
Version
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
colossalai.tensor.d_tensor
import
(
from
colossalai.tensor.d_tensor
import
(
...
@@ -663,7 +664,10 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
...
@@ -663,7 +664,10 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
"""
"""
# Do the cleaning up as in src code of Pytorch.
# Do the cleaning up as in src code of Pytorch.
optimizer
.
_hook_for_profile
()
# To support multiprocessing pickle/unpickle.
if
Version
(
torch
.
__version__
)
>=
Version
(
"2.0.0"
):
optimizer
.
_patch_step_function
()
# To support multiprocessing pickle/unpickle
else
:
optimizer
.
_hook_for_profile
()
# To support multiprocessing pickle/unpickle.
optimizer
.
defaults
.
setdefault
(
"differentiable"
,
False
)
optimizer
.
defaults
.
setdefault
(
"differentiable"
,
False
)
...
...
colossalai/zero/gemini/gemini_optimizer.py
View file @
cb3a25a0
...
@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
...
@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
packaging.version
import
Version
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
...
@@ -676,7 +677,10 @@ class GeminiOptimizer(OptimizerWrapper):
...
@@ -676,7 +677,10 @@ class GeminiOptimizer(OptimizerWrapper):
def
optimizer_loading_epilogue
(
self
):
def
optimizer_loading_epilogue
(
self
):
# Epilogue when loading state_dict to pytorch optimizer.
# Epilogue when loading state_dict to pytorch optimizer.
self
.
optim
.
_hook_for_profile
()
# To support multiprocessing pickle/unpickle.
if
Version
(
torch
.
__version__
)
>=
Version
(
"2.0.0"
):
self
.
optim
.
_patch_step_function
()
# To support multiprocessing pickle/unpickle
else
:
self
.
optim
.
_hook_for_profile
()
# To support multiprocessing pickle/unpickle.
self
.
optim
.
defaults
.
setdefault
(
"differentiable"
,
False
)
self
.
optim
.
defaults
.
setdefault
(
"differentiable"
,
False
)
def
load_state_dict
(
self
,
state_dict
:
dict
):
def
load_state_dict
(
self
,
state_dict
:
dict
):
...
...
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
View file @
cb3a25a0
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
packaging.version
import
Version
from
torch.optim
import
Adam
from
torch.optim
import
Adam
from
utils
import
shared_tempdir
from
utils
import
shared_tempdir
...
@@ -19,14 +20,8 @@ from colossalai.testing import (
...
@@ -19,14 +20,8 @@ from colossalai.testing import (
)
)
from
tests.kit.model_zoo
import
model_zoo
from
tests.kit.model_zoo
import
model_zoo
if
Version
(
torch
.
__version__
)
<
Version
(
"2.0.0"
):
@
clear_cache_before_run
()
TEST_CONFIGS
=
[
@
parameterize
(
"shard"
,
[
True
,
False
])
@
parameterize
(
"model_name"
,
[
"transformers_gpt"
])
@
parameterize
(
"size_per_shard"
,
[
32
])
@
parameterize
(
"test_config"
,
[
{
{
"tp_size"
:
4
,
"tp_size"
:
4
,
"pp_size"
:
1
,
"pp_size"
:
1
,
...
@@ -35,8 +30,19 @@ from tests.kit.model_zoo import model_zoo
...
@@ -35,8 +30,19 @@ from tests.kit.model_zoo import model_zoo
{
"tp_size"
:
2
,
"pp_size"
:
2
,
"num_microbatches"
:
4
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
{
"tp_size"
:
2
,
"pp_size"
:
2
,
"num_microbatches"
:
4
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
{
"tp_size"
:
2
,
"pp_size"
:
1
,
"zero_stage"
:
2
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
{
"tp_size"
:
2
,
"pp_size"
:
1
,
"zero_stage"
:
2
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
{
"tp_size"
:
1
,
"pp_size"
:
2
,
"num_microbatches"
:
4
,
"zero_stage"
:
1
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
{
"tp_size"
:
1
,
"pp_size"
:
2
,
"num_microbatches"
:
4
,
"zero_stage"
:
1
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
],
]
)
else
:
TEST_CONFIGS
=
[
# TODO(ver217): other configs lead to hang
{
"tp_size"
:
1
,
"pp_size"
:
2
,
"num_microbatches"
:
4
,
"zero_stage"
:
1
,
"precision"
:
"fp16"
,
"initial_scale"
:
1
},
]
@
clear_cache_before_run
()
@
parameterize
(
"shard"
,
[
True
,
False
])
@
parameterize
(
"model_name"
,
[
"transformers_gpt"
])
@
parameterize
(
"size_per_shard"
,
[
32
])
@
parameterize
(
"test_config"
,
TEST_CONFIGS
)
def
exam_state_dict
(
shard
:
bool
,
model_name
:
str
,
size_per_shard
:
int
,
test_config
:
dict
):
def
exam_state_dict
(
shard
:
bool
,
model_name
:
str
,
size_per_shard
:
int
,
test_config
:
dict
):
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
loss_fn
,
_
)
=
next
(
(
model_fn
,
data_gen_fn
,
output_transform_fn
,
loss_fn
,
_
)
=
next
(
iter
(
model_zoo
.
get_sub_registry
(
model_name
).
values
())
iter
(
model_zoo
.
get_sub_registry
(
model_name
).
values
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment