Commit 0a8f3c85 authored by Jianghai's avatar Jianghai Committed by Hongxin Liu
Browse files

[hotfix] fix opt pipeline (#4293)

* opt forward and test

* pause

* finish opt model pipeline

* finish opt pipeline

* opt forward and test

* pause

* finish opt model pipeline

* finish opt pipeline

* fix opt

* set transformers version

* refactor the test pipeline

* fix bug
parent d8408d18
...@@ -12,6 +12,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss ...@@ -12,6 +12,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [ __all__ = [
...@@ -198,8 +199,8 @@ class OPTForCausalLMPolicy(OPTPolicy): ...@@ -198,8 +199,8 @@ class OPTForCausalLMPolicy(OPTPolicy):
def get_shared_params(self) -> List[Dict[int, Tensor]]: def get_shared_params(self) -> List[Dict[int, Tensor]]:
opt_model = self.model opt_model = self.model
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
num_stages = self.pipeline_stage_manager.num_stages num_stages = self.pipeline_stage_manager.num_stages
if self.pipeline_stage_manager and num_stages > 1:
if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight): if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}] return [{0: opt_model.model.decoder.embed_tokens.weight, num_stages - 1: opt_model.lm_head.weight}]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment