OpenDAS / ColossalAI

Commit 451e9142: fix flash attn (#5209)
Authored Jan 03, 2024 by flybird11111; committed by GitHub, Jan 03, 2024 (unverified)
Parent: 365671be
Showing 2 changed files with 6 additions and 5 deletions (+6, -5):

colossalai/shardformer/modeling/llama.py    +3, -4
colossalai/shardformer/policies/llama.py    +3, -1
colossalai/shardformer/modeling/llama.py

@@ -414,7 +414,7 @@ class LlamaPipelineForwards:
         return {"hidden_states": hidden_states}


-def get_llama_flash_attention_forward():
+def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

     from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
@@ -470,14 +470,13 @@ def get_llama_flash_attention_forward():

         flash_attention_mask = None
         attn_mask_type = AttnMaskType.causal
-        if attention_mask != None:
+        if not getattr(shard_config, "causal_lm", False) and attention_mask != None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
             flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-            if not torch.all(flash_attention_mask):
-                attn_mask_type = AttnMaskType.paddedcausal
+            attn_mask_type = AttnMaskType.paddedcausal

         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
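The net effect of this hunk is that the patched forward now consults the ShardConfig at call time: when the policy has flagged the config as a causal LM, the incoming attention mask is ignored and the plain causal mask type is kept; otherwise any explicit mask is reduced to a key-padding mask and the padded-causal kernel path is taken unconditionally (the old torch.all check is gone). A minimal sketch of that decision follows; the AttnMaskType enum and the select_mask_type helper are illustrative stand-ins, not ColossalAI's own code, and a (bsz, 1, q_len, kv_seq_len) mask is assumed:

import torch
from enum import Enum, auto

class AttnMaskType(Enum):  # stand-in for colossalai.kernel.cuda_native.AttnMaskType
    causal = auto()
    paddedcausal = auto()

def select_mask_type(attention_mask, causal_lm):
    # Restates the post-fix mask handling from the patched forward (illustrative).
    flash_attention_mask = None
    attn_mask_type = AttnMaskType.causal
    if not causal_lm and attention_mask is not None:
        # keep only the last query row of the 4-D mask and invert it into a key-padding mask
        flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
        # any explicit mask now selects the padded-causal path
        attn_mask_type = AttnMaskType.paddedcausal
    return flash_attention_mask, attn_mask_type

mask = torch.ones(2, 1, 4, 4)
print(select_mask_type(mask, causal_lm=True))   # (None, AttnMaskType.causal)
print(select_mask_type(mask, causal_lm=False))  # (tensor of key-padding bits, AttnMaskType.paddedcausal)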
colossalai/shardformer/policies/llama.py

@@ -130,7 +130,7 @@ class LlamaPolicy(Policy):
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_llama_flash_attention_forward(),
+                    "forward": get_llama_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
                 target_key=LlamaAttention,
@@ -250,6 +250,8 @@ class LlamaForCausalLMPolicy(LlamaPolicy):

         policy = super().module_policy()

+        setattr(self.shard_config, "causal_lm", True)
+
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
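On the policy side, LlamaPolicy now hands its ShardConfig to get_llama_flash_attention_forward, and LlamaForCausalLMPolicy tags that same config with causal_lm = True after calling the base module_policy(). The ordering still works because the replacement forward reads the flag with getattr on every call rather than capturing its value once. A small self-contained sketch of that closure behaviour, where DummyShardConfig and make_forward are illustrative stand-ins rather than ColossalAI APIs:

class DummyShardConfig:  # stand-in for colossalai.shardformer.ShardConfig
    enable_flash_attention = True

def make_forward(shard_config):
    # analogous to get_llama_flash_attention_forward: the returned closure keeps a
    # reference to the ShardConfig object and re-reads the flag on every call
    def forward():
        return getattr(shard_config, "causal_lm", False)
    return forward

cfg = DummyShardConfig()
fwd = make_forward(cfg)           # done inside LlamaPolicy.module_policy()
setattr(cfg, "causal_lm", True)   # done later by LlamaForCausalLMPolicy
assert fwd() is True              # the flag set afterwards is still visible at call time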