OpenDAS / ColossalAI · Commits

Commit 97cd0cd5 (unverified)
Authored Nov 16, 2023 by flybird11111; committed via GitHub on Nov 16, 2023
Parent: 3e021547

[shardformer] fix llama error when transformers upgraded. (#5055)

* fix-llama
* Update llama.py
Showing 1 changed file with 14 additions and 4 deletions.
colossalai/shardformer/modeling/llama.py (+14, -4) @ 97cd0cd5
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
...
@@ -13,6 +13,11 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager

+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False

 class LlamaPipelineForwards:
     """
...
@@ -97,9 +102,14 @@ class LlamaPipelineForwards:
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
-        )
+        if LATEST_VERSION:
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            )

         if self.gradient_checkpointing and self.training:
             if use_cache:
...
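
The core of the fix is import-based feature detection: probe for the new module-level helper once at import time, then branch at the call site. Below is a minimal, self-contained sketch of that pattern. The build_causal_mask wrapper and its model parameter are illustrative additions, not part of the commit, and the "around v4.35" cutoff mentioned in the comments is an assumption about when transformers moved the mask builder to module level.

# Sketch of the import-probe pattern used by this commit. Only the
# try/except gating and the two helper names come from the actual diff;
# the dispatch wrapper below is hypothetical.
try:
    # Newer transformers releases (around v4.35, an assumption) expose the
    # causal-mask builder as a module-level function.
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
    LATEST_VERSION = True
except ImportError:
    # Older releases only provide the logic as a method on the model,
    # LlamaModel._prepare_decoder_attention_mask.
    LATEST_VERSION = False


def build_causal_mask(model, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    """Dispatch to whichever mask helper this transformers install provides."""
    if LATEST_VERSION:
        # Post-refactor path: module-level helper, same argument order as the diff.
        return _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )
    # Pre-refactor fallback: the method lived on the model instance.
    return model._prepare_decoder_attention_mask(
        attention_mask, input_shape, inputs_embeds, past_key_values_length
    )

Probing the import rather than parsing transformers.__version__ keeps the module importable on both sides of the API change without hard-coding a version string, which is why the commit can support upgraded and older transformers installs with the same file.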