Unverified commit 5e1c93d7, authored by digger yu, committed by GitHub
Browse files

[hotfix] fix typo change MoECheckpintIO to MoECheckpointIO (#5335)


Co-authored-by: binmakeswell <binmakeswell@gmail.com>
parent a7ae2b5b
...@@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None: ...@@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
def auto_set_accelerator() -> None: def auto_set_accelerator() -> None:
""" """
Automatically check if any accelerator is available. Automatically check if any accelerator is available.
If an accelerator is availabe, set it as the global accelerator. If an accelerator is available, set it as the global accelerator.
""" """
global _ACCELERATOR global _ACCELERATOR
......
...@@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase): ...@@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase):
) )
def __del__(self): def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh""" """Destroy the process groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups() self.pg_mesh.destroy_mesh_process_groups()
def support_no_sync(self) -> bool: def support_no_sync(self) -> bool:
......
...@@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.max_norm = max_norm self.max_norm = max_norm
def __del__(self): def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh""" """Destroy the process groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups() self.pg_mesh.destroy_mesh_process_groups()
@property @property
......
...@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import ( ...@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
) )
from colossalai.cluster import ProcessGroupMesh from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER, MoECheckpintIO from colossalai.moe import MOE_MANAGER, MoECheckpointIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig from colossalai.shardformer import ShardConfig
...@@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ...@@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
**_kwargs, **_kwargs,
) )
def get_checkpoint_io(self) -> MoECheckpintIO:
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None: if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
else: else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io return self.checkpoint_io
......
...@@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
pp_group (ProcessGroup): Process group along pipeline parallel dimension. pp_group (ProcessGroup): Process group along pipeline parallel dimension.
tp_group (ProcessGroup): Process group along tensor parallel dimension. tp_group (ProcessGroup): Process group along tensor parallel dimension.
zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2]. zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True. verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True.
""" """
def __init__( def __init__(
...@@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group # obtain updated param group
new_pg = copy.deepcopy(saved_pg) new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg) updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups}) optimizer.optim.__dict__.update({"param_groups": updated_groups})
......
from .checkpoint import MoECheckpintIO from .checkpoint import MoECheckpointIO
from .experts import MLPExperts from .experts import MLPExperts
from .layers import SparseMLP, apply_load_balance from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER from .manager import MOE_MANAGER
...@@ -14,7 +14,7 @@ __all__ = [ ...@@ -14,7 +14,7 @@ __all__ = [
"NormalNoiseGenerator", "NormalNoiseGenerator",
"UniformNoiseGenerator", "UniformNoiseGenerator",
"SparseMLP", "SparseMLP",
"MoECheckpintIO", "MoECheckpointIO",
"MOE_MANAGER", "MOE_MANAGER",
"apply_load_balance", "apply_load_balance",
] ]
...@@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import ( ...@@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import (
) )
class MoECheckpintIO(HybridParallelCheckpointIO): class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__( def __init__(
self, self,
dp_group: ProcessGroup, dp_group: ProcessGroup,
...@@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO): ...@@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups): for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group # obtain updated param group
new_pg = copy.deepcopy(saved_pg) new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change. new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg) updated_groups.append(new_pg)
# ep param group # ep param group
if len(optimizer.optim.param_groups) > len(saved_groups): if len(optimizer.optim.param_groups) > len(saved_groups):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment