Unverified Commit d8ceeac1 authored by Baizhou Zhang, committed by GitHub

[hotfix] fix typo in hybrid parallel io (#4697)

parent 8844691f
@@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
-from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@@ -513,7 +513,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                           **_kwargs)

     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io

     def no_sync(self, model: Module) -> Iterator[None]:
...
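For context, a minimal usage sketch (not part of this diff) of how the corrected class is reached in practice: user code goes through `Booster` with `HybridParallelPlugin`, and `save_model` delegates to the `HybridParallelCheckpointIO` returned by `get_checkpoint_io()` above. The plugin sizes and launcher setup below are illustrative assumptions.

```python
# Hedged sketch, assuming a torchrun launch with world_size >= 2;
# sizes and paths are illustrative, not taken from this commit.
import colossalai
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})              # needs torchrun env vars
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # assumed 2-way TP
booster = Booster(plugin=plugin)

model, *_ = booster.boost(nn.Linear(8, 8))
# save_model routes through plugin.get_checkpoint_io(), i.e. the
# HybridParallelCheckpointIO fixed by this commit.
booster.save_model(model, "ckpt_dir", shard=True)
```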
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
-from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
+from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
 from .index_file import CheckpointIndexFile

 __all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HybridParallelCheckpointIO']
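Worth noting why the misspelling bit here: `__all__` in this `__init__.py` already advertised the corrected name, so before the fix the module exported a name it never defined, and a star-import would fail. A self-contained illustration (hypothetical module name, not ColossalAI code):

```python
# Hypothetical stand-in module reproducing the pre-fix mismatch between the
# defined class name and the name advertised in __all__.
import sys
import types

mod = types.ModuleType("fake_checkpoint_io")
exec(
    "class HypridParallelCheckpointIO:\n"       # misspelled definition
    "    pass\n"
    "__all__ = ['HybridParallelCheckpointIO']\n",  # correctly spelled export
    mod.__dict__,
)
sys.modules["fake_checkpoint_io"] = mod

try:
    exec("from fake_checkpoint_io import *")
except AttributeError as err:
    # module 'fake_checkpoint_io' has no attribute 'HybridParallelCheckpointIO'
    print(err)
```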
@@ -39,7 +39,7 @@ except ImportError:
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'


-class HypridParallelCheckpointIO(GeneralCheckpointIO):
+class HybridParallelCheckpointIO(GeneralCheckpointIO):
     """
     CheckpointIO for Hybrid Parallel Training.
@@ -136,7 +136,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
                     param_id = param_info['param2id'][id(working_param)]
                     original_shape = param_info['param2shape'][id(working_param)]
-                    state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
+                    state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
                                                                                            working_param,
                                                                                            original_shape=original_shape,
                                                                                            dp_group=dp_group,
@@ -189,7 +189,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
+        state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
         control_saving = (self.tp_rank == 0)
@@ -385,7 +385,7 @@ class HypridParallelCheckpointIO(GeneralCheckpointIO):
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
         # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
-        state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
+        state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
             optimizer,
             use_zero=self.use_zero,
             dp_group=self.dp_group,
...
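The comment pairs in the last two hunks describe the saving discipline: every rank produces its shard, but only `tp_rank == 0` writes model shards, and only `dp_rank == 0 and tp_rank == 0` writes optimizer-state shards, so each shard lands on disk exactly once. A generic sketch of that gating (hypothetical helper, not the class's actual code):

```python
# Hedged sketch of rank-gated shard saving with torch.distributed; the helper
# name and arguments are hypothetical illustrations of the pattern above.
import torch
import torch.distributed as dist

def save_shard(shard: dict, path: str, tp_group, dp_group=None) -> None:
    tp_rank = dist.get_rank(group=tp_group)
    dp_rank = dist.get_rank(group=dp_group) if dp_group is not None else 0
    # Every rank holds a shard, but exactly one rank per group combination
    # is the "controlling" rank that actually writes it.
    control_saving = (tp_rank == 0 and dp_rank == 0)
    if control_saving:
        torch.save(shard, path)
```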