"vscode:/vscode.git/clone" did not exist on "a0436a62ee4ec65b565423ae34aeae28b0d689d0"
Unverified Commit 351351a3 authored by Bin Jia's avatar Bin Jia Committed by GitHub
Browse files

[shardformer/sequence parallel] not support opt of seq-parallel, add warning...

[shardformer/sequence parallel] not support opt of seq-parallel, add warning and fix a bug in gpt2 pp (#4488)
parent 5545114f
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List from typing import Callable, Dict, List
...@@ -39,6 +40,9 @@ class OPTPolicy(Policy): ...@@ -39,6 +40,9 @@ class OPTPolicy(Policy):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {} policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[ policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment