OpenDAS / ColossalAI · Commits

Commit a4cec171 (unverified)
Authored Feb 05, 2024 by Hongxin Liu; committed by GitHub on Feb 05, 2024
Parent: 73f9f23f

[llama] add flash attn patch for npu (#5362)

Changes: 1 changed file with 321 additions and 181 deletions
+321 -181  applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
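The diff below wraps the CUDA-only flash-attn imports in an accelerator check and adds an NPU branch built on torch_npu. In both cases the entry point is replace_with_flash_attention, which patches a loaded model in place. A minimal usage sketch, assuming the package is importable as colossal_llama2 and using a placeholder checkpoint path (not part of this commit):

    # Usage sketch (not part of this commit): apply the accelerator-specific patch in place.
    from transformers import LlamaForCausalLM

    from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention

    model = LlamaForCausalLM.from_pretrained("path/to/llama-2-checkpoint")  # placeholder path
    replace_with_flash_attention(model)  # CUDA: flash-attn kernels; NPU: torch_npu fused kernels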
applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py — view file @ a4cec171
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import math
from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
...
@@ -19,18 +19,23 @@ from transformers.models.llama.modeling_llama import (
    repeat_kv,
)

from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger

logger = get_dist_logger()

if get_accelerator().name == "cuda":
    from flash_attn.bert_padding import pad_input, unpad_input
    from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
    from flash_attn.ops.rms_norm import rms_norm

    def _prepare_decoder_attention_mask(
        self: LlamaModel,
        attention_mask: torch.BoolTensor,
        input_shape: torch.Size,
        inputs_embeds: torch.Tensor,
        past_key_values_length: int,
    ) -> Optional[torch.Tensor]:
        """
        Decoder attention mask
        """
...
@@ -51,8 +56,7 @@ def _prepare_decoder_attention_mask(
            return None  # Faster
        return attention_mask

    def attention_forward(
        self: LlamaAttention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
...
@@ -61,7 +65,7 @@ def attention_forward(
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
        """
...
@@ -163,7 +167,9 @@ def attention_forward(
        if key_padding_mask is None:
            # (bsz, past_kv_len + q_len, num_heads, head_dim)
            output = flash_attn_func(q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=True)  # (bsz, )
            output = rearrange(output, pattern="... h d -> ... (h d)")  # (bsz, past_kv_len + q_len, num_heads * head_dim)
        else:
            q, indices, cu_q_lens, max_q_len = unpad_input(hidden_states=q, attention_mask=key_padding_mask)
            kv, _, cu_kv_lens, max_kv_len = unpad_input(
...
@@ -194,15 +200,13 @@ def attention_forward(
        output = self.o_proj(output)  # (bsz, q_len, hidden_size)
        return output, None, past_key_value

    def rms_norm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward function for RMS Norm
        """
        return rms_norm(x=hidden_states, weight=self.weight, epsilon=self.variance_epsilon)

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.forward = MethodType(attention_forward, module)
...
@@ -210,3 +214,139 @@ def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
                module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
            if isinstance(module, LlamaRMSNorm):
                module.forward = MethodType(rms_norm_forward, module)

elif get_accelerator().name == "npu":
    import torch_npu

    class NPULlamaAttention(LlamaAttention):
        use_flash: bool = True

        def __init__(self, config: LlamaConfig):
            super().__init__(config)
            self.setup()

        def setup(self):
            self._softmax_scale = 1 / math.sqrt(self.head_dim)

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            bsz, q_len, _ = hidden_states.size()

            if self.config.pretraining_tp > 1:
                key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
                query_slices = self.q_proj.weight.split(
                    (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
                )
                key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
                value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

                query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
                query_states = torch.cat(query_states, dim=-1)

                key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
                key_states = torch.cat(key_states, dim=-1)

                value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
                value_states = torch.cat(value_states, dim=-1)
            else:
                query_states = self.q_proj(hidden_states)
                key_states = self.k_proj(hidden_states)
                value_states = self.v_proj(hidden_states)

            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            kv_seq_len = key_states.shape[-2]
            if past_key_value is not None:
                kv_seq_len += past_key_value[0].shape[-2]
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

            if past_key_value is not None:
                # reuse k, v, self_attention
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

            past_key_value = (key_states, value_states) if use_cache else None

            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            if not self.use_flash:
                attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

                if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                        f" {attn_weights.size()}"
                    )

                if attention_mask is not None:
                    if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                        raise ValueError(
                            f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                        )
                    attn_weights = attn_weights + attention_mask

                # upcast attention to fp32
                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
                attn_output = torch.matmul(attn_weights, value_states)
            else:
                attn_output, *_ = torch_npu.npu_fusion_attention(
                    query_states,
                    key_states,
                    value_states,
                    self.num_heads,
                    "BNSD",
                    atten_mask=attention_mask.bool(),
                    scale=self._softmax_scale,
                    padding_mask=None,
                    pre_tockens=65535,
                    next_tockens=0,
                    keep_prob=1.0,
                    inner_precise=0,
                )

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            if self.config.pretraining_tp > 1:
                attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
                o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
                attn_output = sum(
                    [F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
                )
            else:
                attn_output = self.o_proj(attn_output)

            if not output_attentions:
                attn_weights = None
            return attn_output, attn_weights, past_key_value

    class NPURMSNorm(LlamaRMSNorm):
        def forward(self, hidden_states):
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)[0]

    def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
        for name, module in model.named_modules():
            if isinstance(module, LlamaAttention):
                module.__class__ = NPULlamaAttention
                module.setup()
            if isinstance(module, LlamaRMSNorm):
                module.__class__ = NPURMSNorm
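On the NPU side, NPULlamaAttention keeps an eager softmax fallback behind the use_flash class attribute, so the fused torch_npu.npu_fusion_attention path can be cross-checked against plain matmul attention. A small sketch of that kind of check, assuming an Ascend NPU environment, a model patched as in the earlier sketch, and an input_ids tensor already on the device (all assumptions, not part of the commit):

    # Sketch: compare the fused NPU kernel against the eager fallback on the same batch.
    # use_flash is a class attribute, so flipping it affects every patched attention module.
    import torch

    from colossal_llama2.utils.flash_attention_patch import NPULlamaAttention

    with torch.no_grad():
        NPULlamaAttention.use_flash = False   # eager matmul + fp32 softmax path
        eager_logits = model(input_ids).logits

        NPULlamaAttention.use_flash = True    # torch_npu.npu_fusion_attention path
        fused_logits = model(input_ids).logits

    print((eager_logits - fused_logits).abs().max())  # expect only a small numerical difference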