Unverified Commit 2a2ec49a authored by flybird11111, committed by GitHub

[plugin] fix 3d checkpoint load when booster.boost is called without an optimizer (#5135)

* fix 3d checkpoint load when booster boost without optimizer

fix 3d checkpoint load when booster boost without optimizer

* test ci

* revert ci

* fix

fix
parent f6731db6
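
For context, the scenario this commit fixes is sketched below: a model boosted through HybridParallelPlugin without an optimizer, followed by loading a sharded (3D-parallel) checkpoint. This is a minimal sketch, not part of the diff; the plugin sizes, placeholder model, and checkpoint path are illustrative assumptions.

# Minimal sketch of the affected flow, assuming distributed init via torchrun
# and a colossalai version contemporary with this commit; the model, parallel
# sizes, and checkpoint path are illustrative placeholders.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="fp16")
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)  # placeholder model

# No optimizer is passed to boost() -- this is the case that previously broke
# the subsequent checkpoint load.
model, *_ = booster.boost(model)

# Load a previously saved sharded checkpoint (hypothetical path).
booster.load_model(model, "./hybrid_ckpt")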
@@ -21,7 +21,7 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x


-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
...
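
The change above makes HybridParallelModule mix in AMPModelMixin alongside ModelWrapper, so the wrapped model exposes the AMP hook that the checkpoint-loading path invokes even when boost() was given no optimizer. Below is a minimal sketch of that mixin pattern; the classes are simplified stand-ins rather than the real ColossalAI implementations, and the update_master_params name is an assumption about the hook AMPModelMixin supplies.

# Simplified stand-ins illustrating the mixin pattern from the diff; only the
# names ModelWrapper, AMPModelMixin, and HybridParallelModule come from the
# change itself -- the update_master_params hook is an assumption.
import torch.nn as nn


class ModelWrapper(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class AMPModelMixin:
    def update_master_params(self):
        # Default no-op: nothing to sync when there is no optimizer/AMP state.
        pass


class HybridParallelModule(ModelWrapper, AMPModelMixin):
    # With the mixin in place, checkpoint-loading code can call
    # model.update_master_params() unconditionally, even when booster.boost()
    # was called without an optimizer.
    pass


wrapped = HybridParallelModule(nn.Linear(4, 4))
wrapped.update_master_params()  # safe no-op instead of an AttributeError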
@@ -116,6 +116,9 @@ def check_gemini_plugin(
             "transformers_falcon_for_sequence_classification",
             "transformers_falcon_for_token_classification",
             "transformers_falcon_for_question_answering",
+            "transformers_gptj_lm",  # lead to OOM when running in ci
+            "transformers_gptj_for_question_answering",
+            "transformers_gptj_for_sequence_classification",
         ]:
             continue
...