Unverified Commit 8b7d5359 authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

fix gptj (#5652)

parent 1b387ca9
......@@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
......@@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
suffix="attn.k_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
......@@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
suffix="attn.q_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
......@@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
suffix="attn.v_proj",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="attn.out_proj",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_in",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="mlp.fc_out",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment