Unverified Commit 5e1c93d7 authored by digger yu, committed by GitHub
Browse files

[hotfix] fix typo: change MoECheckpintIO to MoECheckpointIO (#5335)


Co-authored-by: binmakeswell <binmakeswell@gmail.com>
parent a7ae2b5b
......@@ -40,7 +40,7 @@ def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:
def auto_set_accelerator() -> None:
"""
Automatically check if any accelerator is available.
If an accelerator is availabe, set it as the global accelerator.
If an accelerator is available, set it as the global accelerator.
"""
global _ACCELERATOR
......
......@@ -437,7 +437,7 @@ class GeminiPlugin(DPPluginBase):
)
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
"""Destroy the process groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
def support_no_sync(self) -> bool:
......
......@@ -1067,7 +1067,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.max_norm = max_norm
def __del__(self):
"""Destroy the prcess groups in ProcessGroupMesh"""
"""Destroy the process groups in ProcessGroupMesh"""
self.pg_mesh.destroy_mesh_process_groups()
@property
......
......@@ -22,7 +22,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
)
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.moe import MOE_MANAGER, MoECheckpintIO
from colossalai.moe import MOE_MANAGER, MoECheckpointIO
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
......@@ -341,9 +341,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
**_kwargs,
)
def get_checkpoint_io(self) -> MoECheckpintIO:
def get_checkpoint_io(self) -> MoECheckpointIO:
if self.checkpoint_io is None:
self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
self.checkpoint_io = MoECheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
else:
self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
......
......@@ -51,7 +51,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
pp_group (ProcessGroup): Process group along pipeline parallel dimension.
tp_group (ProcessGroup): Process group along tensor parallel dimension.
zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
verbose (bool, optional): Whether to print logging massage when saving/loading has been succesfully executed. Defaults to True.
verbose (bool, optional): Whether to print logging massage when saving/loading has been successfully executed. Defaults to True.
"""
def __init__(
......@@ -574,7 +574,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({"param_groups": updated_groups})
......
from .checkpoint import MoECheckpintIO
from .checkpoint import MoECheckpointIO
from .experts import MLPExperts
from .layers import SparseMLP, apply_load_balance
from .manager import MOE_MANAGER
......@@ -14,7 +14,7 @@ __all__ = [
"NormalNoiseGenerator",
"UniformNoiseGenerator",
"SparseMLP",
"MoECheckpintIO",
"MoECheckpointIO",
"MOE_MANAGER",
"apply_load_balance",
]
......@@ -40,7 +40,7 @@ from colossalai.tensor.moe_tensor.api import (
)
class MoECheckpintIO(HybridParallelCheckpointIO):
class MoECheckpointIO(HybridParallelCheckpointIO):
def __init__(
self,
dp_group: ProcessGroup,
......@@ -373,7 +373,7 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
new_pg["params"] = old_pg["params"] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
# ep param group
if len(optimizer.optim.param_groups) > len(saved_groups):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment