OpenDAS / ColossalAI · Commits

Commit d882d18c (unverified)
Authored Feb 27, 2024 by Hongxin Liu; committed via GitHub on Feb 27, 2024

[example] reuse flash attn patch (#5400)

Parent: 95c21e39

Showing 5 changed files with 8 additions and 177 deletions (+8 -177)
examples/language/llama2/attn.py       +1 -84
examples/language/llama2/attn.py       +1 -84
examples/language/llama2/benchmark.py  +2 -3
examples/language/llama2/finetune.py   +2 -3
examples/language/llama2/pretrain.py   +2 -3
examples/language/llama2/attn.py (deleted, mode 100644 → 0) @ 95c21e39
from types import MethodType
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

SUPPORT_XFORMERS = False
SUPPORT_FLASH2 = False
try:
    import xformers.ops as xops

    SUPPORT_XFORMERS = True
except ImportError:
    pass

try:
    from flash_attn import flash_attn_func

    SUPPORT_FLASH2 = True
except ImportError:
    pass

SUPPORT_FLASH = SUPPORT_XFORMERS or SUPPORT_FLASH2


def llama_flash_attention(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # q, k, v is [B, H, S, K] and xformers need [B, S, H, K]. returns [B, S, H, K]
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)
    if SUPPORT_FLASH2:
        attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)
    else:
        attn_output = xops.memory_efficient_attention(
            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
        )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value


def replace_xformers(model: nn.Module):
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(llama_flash_attention, module)
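For context, this is how the deleted module was consumed: patch a loaded model in place, exactly as the old scripts below did. A minimal usage sketch (the checkpoint name is illustrative, and the deleted attn.py must be importable):

# Usage sketch for the deleted attn.py above; "meta-llama/Llama-2-7b-hf" is illustrative.
import torch
from transformers import LlamaForCausalLM

from attn import SUPPORT_FLASH, replace_xformers

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16)
assert SUPPORT_FLASH, "install flash-attn or xformers first"
replace_xformers(model)  # every LlamaAttention.forward now runs llama_flash_attention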
examples/language/llama2/attn.py (new symlink, mode 0 → 120000) @ d882d18c
../../../applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py
\ No newline at end of file
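The example's attn.py is now a symlink into the shared Colossal-LLaMA-2 utilities, so all three llama2 scripts reuse one maintained patch instead of a local copy. The patch module itself is outside this diff; as a hedged sketch only, its exported replace_with_flash_attention presumably mirrors the deleted replace_xformers:

# Hypothetical sketch: the real flash_attention_patch.py is not shown in this
# diff, so this only mirrors the deleted replace_xformers above. It assumes a
# flash-attention forward such as llama_flash_attention (from the deleted
# module) is in scope; the actual helper may differ.
from types import MethodType

import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaAttention


def replace_with_flash_attention(model: nn.Module) -> None:
    # Bind a flash-attention forward onto every LlamaAttention module, in place.
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(llama_flash_attention, module)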
examples/language/llama2/benchmark.py @ d882d18c
@@ -3,7 +3,7 @@ import resource
 from contextlib import nullcontext
 
 import torch
-from attn import SUPPORT_FLASH, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
 from performance_evaluator import PerformanceEvaluator
@@ -188,8 +188,7 @@ def main():
         model.gradient_checkpointing_enable()
 
     if args.xformers:
-        assert SUPPORT_FLASH, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
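Note that the call site shrinks from two lines to one: the explicit SUPPORT_FLASH assert disappears, which suggests the shared replace_with_flash_attention handles availability checking itself inside flash_attention_patch.py; since that file is not part of this diff, this is an inference. The same swap repeats in finetune.py and pretrain.py below.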
examples/language/llama2/finetune.py @ d882d18c
@@ -9,7 +9,7 @@ from typing import Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import load_json, prepare_dataloader, save_json
 from datasets import load_dataset
 from torch.optim import Optimizer
@@ -219,8 +219,7 @@ def main():
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if args.flash_attention:
-        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
examples/language/llama2/pretrain.py @ d882d18c
@@ -8,7 +8,7 @@ from typing import Optional, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from attn import SUPPORT_XFORMERS, replace_xformers
+from attn import replace_with_flash_attention
 from data_utils import load_json, prepare_dataloader, save_json
 from datasets import load_dataset
 from torch.optim import Optimizer
@@ -238,8 +238,7 @@ def main():
     if args.grad_checkpoint:
         model.gradient_checkpointing_enable()
     if args.flash_attention:
-        assert SUPPORT_XFORMERS, "Use flash attention while xfomers is not installed"
-        replace_xformers(model)
+        replace_with_flash_attention(model)
 
     model_numel = get_model_numel(model)
     coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")