OpenDAS / ColossalAI · Commit d8ceeac1 (unverified)
Authored Sep 12, 2023 by Baizhou Zhang; committed by GitHub on Sep 12, 2023

[hotfix] fix typo in hybrid parallel io (#4697)
Parent: 8844691f

Showing 3 changed files, with 7 additions and 7 deletions (+7, -7):
colossalai/booster/plugin/hybrid_parallel_plugin.py        +2 -2
colossalai/checkpoint_io/__init__.py                       +1 -1
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py  +4 -4
colossalai/booster/plugin/hybrid_parallel_plugin.py (view file @ d8ceeac1)
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule

@@ -513,7 +513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)

     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

     def no_sync(self, model: Module) -> Iterator[None]:
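The second hunk also shows the constructor the plugin relies on, HybridParallelCheckpointIO(dp_group, pp_group, tp_group, zero_stage), handed back to callers through get_checkpoint_io() as a plain CheckpointIO. A minimal usage sketch follows; the plugin arguments and the commented save call are illustrative assumptions, not part of this commit, and torch.distributed must already have been initialized (e.g. via colossalai.launch):

# Sketch: obtaining the (correctly spelled) checkpoint IO object from the plugin.
# Assumes the distributed environment has been set up, e.g. via colossalai.launch.
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(tp_size=2, pp_size=1)   # illustrative parallel sizes
checkpoint_io = plugin.get_checkpoint_io()            # HybridParallelCheckpointIO instance

# CheckpointIO exposes save/load helpers; the exact keyword arguments below are
# an assumption based on the base interface, not shown in this diff:
# checkpoint_io.save_model(model, "checkpoints/", shard=True)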
colossalai/checkpoint_io/__init__.py (view file @ d8ceeac1)
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
-from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile

 __all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
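Note that before this commit the package's __all__ already advertised the correct name while the import statement bound the misspelled one, so both the documented import and a wildcard import of the package would fail. A short illustration of the mismatch (the failing lines are kept as comments since they only apply to the pre-fix package):

# Pre-fix behaviour (illustrative): the module attribute was named
# 'HypridParallelCheckpointIO' while __all__ listed 'HybridParallelCheckpointIO'.
#   from colossalai.checkpoint_io import HybridParallelCheckpointIO   # ImportError
#   from colossalai.checkpoint_io import *   # AttributeError: no attribute 'HybridParallelCheckpointIO'
# After the fix, the attribute name and the __all__ entry agree, so this works:
from colossalai.checkpoint_io import HybridParallelCheckpointIO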
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py (view file @ d8ceeac1)
@@ -39,7 +39,7 @@ except ImportError:
     _EXTRA_STATE_KEY_SUFFIX = '_extra_state'


-class HypridParallelCheckpointIO(GeneralCheckpointIO):
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
     """
     CheckpointIO for Hybrid Parallel Training.

@@ -136,7 +136,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
             param_id = param_info['param2id'][id(working_param)]
             original_shape = param_info['param2shape'][id(working_param)]
-            state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+            state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
                                                                                     working_param,
                                                                                     original_shape=original_shape,
                                                                                     dp_group=dp_group,

@@ -189,7 +189,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+        state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
         control_saving = (self.tp_rank == 0)

@@ -385,7 +385,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
         # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-        state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
+        state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
             optimizer,
             use_zero=self.use_zero,
             dp_group=self.dp_group,
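The comments in the last two hunks describe the rank-gated saving pattern this class uses: model shards are written only where tp_rank == 0, and optimizer shards additionally require dp_rank == 0 when ZeRO shards states along dp_group. A generic sketch of that pattern, not the ColossalAI implementation, with illustrative function and path names:

# Generic sketch of rank-gated shard saving (illustrative, not ColossalAI code).
# Every rank holds its shards, but only the designated "controller" ranks write files.
import torch
import torch.distributed as dist

def save_optimizer_shards(state_dict_shards, out_dir, tp_group, dp_group):
    # Ranks are taken within the tensor- and data-parallel groups, mirroring
    # the tp_rank / dp_rank checks in the hunks above.
    tp_rank = dist.get_rank(group=tp_group)
    dp_rank = dist.get_rank(group=dp_group)
    control_saving = (dp_rank == 0 and tp_rank == 0)
    if not control_saving:
        return
    for idx, shard in enumerate(state_dict_shards):
        torch.save(shard, f"{out_dir}/optimizer_shard_{idx:05d}.bin")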