Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8b7d5359
Unverified
Commit
8b7d5359
authored
Apr 26, 2024
by
flybird11111
Committed by
GitHub
Apr 26, 2024
Browse files
fix gptj (#5652)
parent
1b387ca9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
7 deletions
+0
-7
colossalai/shardformer/policies/gptj.py
colossalai/shardformer/policies/gptj.py
+0
-7
No files found.
colossalai/shardformer/policies/gptj.py
View file @
8b7d5359
...
@@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
...
@@ -54,7 +54,6 @@ class GPTJPolicy(Policy):
if
self
.
shard_config
.
enable_sequence_parallelism
:
if
self
.
shard_config
.
enable_sequence_parallelism
:
self
.
shard_config
.
enable_sequence_parallelism
=
False
self
.
shard_config
.
enable_sequence_parallelism
=
False
warnings
.
warn
(
"GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
warnings
.
warn
(
"GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag."
)
use_sequence_parallel
=
self
.
shard_config
.
enable_sequence_parallelism
overlap
=
self
.
shard_config
.
enable_sequence_overlap
overlap
=
self
.
shard_config
.
enable_sequence_overlap
if
self
.
shard_config
.
enable_tensor_parallelism
:
if
self
.
shard_config
.
enable_tensor_parallelism
:
...
@@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
...
@@ -78,7 +77,6 @@ class GPTJPolicy(Policy):
suffix
=
"attn.k_proj"
,
suffix
=
"attn.k_proj"
,
target_module
=
col_nn
.
Linear1D_Col
,
target_module
=
col_nn
.
Linear1D_Col
,
kwargs
=
{
kwargs
=
{
"seq_parallel"
:
use_sequence_parallel
,
"overlap"
:
overlap
,
"overlap"
:
overlap
,
},
},
),
),
...
@@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
...
@@ -86,7 +84,6 @@ class GPTJPolicy(Policy):
suffix
=
"attn.q_proj"
,
suffix
=
"attn.q_proj"
,
target_module
=
col_nn
.
Linear1D_Col
,
target_module
=
col_nn
.
Linear1D_Col
,
kwargs
=
{
kwargs
=
{
"seq_parallel"
:
use_sequence_parallel
,
"overlap"
:
overlap
,
"overlap"
:
overlap
,
},
},
),
),
...
@@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
...
@@ -94,24 +91,20 @@ class GPTJPolicy(Policy):
suffix
=
"attn.v_proj"
,
suffix
=
"attn.v_proj"
,
target_module
=
col_nn
.
Linear1D_Col
,
target_module
=
col_nn
.
Linear1D_Col
,
kwargs
=
{
kwargs
=
{
"seq_parallel"
:
use_sequence_parallel
,
"overlap"
:
overlap
,
"overlap"
:
overlap
,
},
},
),
),
SubModuleReplacementDescription
(
SubModuleReplacementDescription
(
suffix
=
"attn.out_proj"
,
suffix
=
"attn.out_proj"
,
target_module
=
col_nn
.
Linear1D_Row
,
target_module
=
col_nn
.
Linear1D_Row
,
kwargs
=
{
"seq_parallel"
:
use_sequence_parallel
},
),
),
SubModuleReplacementDescription
(
SubModuleReplacementDescription
(
suffix
=
"mlp.fc_in"
,
suffix
=
"mlp.fc_in"
,
target_module
=
col_nn
.
Linear1D_Col
,
target_module
=
col_nn
.
Linear1D_Col
,
kwargs
=
{
"seq_parallel"
:
use_sequence_parallel
},
),
),
SubModuleReplacementDescription
(
SubModuleReplacementDescription
(
suffix
=
"mlp.fc_out"
,
suffix
=
"mlp.fc_out"
,
target_module
=
col_nn
.
Linear1D_Row
,
target_module
=
col_nn
.
Linear1D_Row
,
kwargs
=
{
"seq_parallel"
:
use_sequence_parallel
},
),
),
SubModuleReplacementDescription
(
SubModuleReplacementDescription
(
suffix
=
"attn.attn_dropout"
,
suffix
=
"attn.attn_dropout"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment