gaoqiong / flash-attention · Commits

Commit 3f7d5786
Authored Dec 24, 2023 by Tri Dao
Pass alibi slopes to flash_attn_with_kvcache during generation

Parent: f8448524
Showing 3 changed files with 10 additions and 2 deletions (+10, -2):

flash_attn/models/baichuan.py    +1  -2
flash_attn/modules/mha.py        +8  -0
tests/models/test_baichuan.py    +1  -0
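For background (not part of the commit itself): ALiBi biases each attention score by -slope * distance, with one slope per head, and flash-attn expects the slopes as an fp32 tensor of shape (nheads,) or (batch, nheads). Below is a minimal sketch of the standard power-of-two slope schedule, included only for orientation; the helper name and the example head count are illustrative, not taken from this commit.

import torch

def alibi_slopes(nheads: int) -> torch.Tensor:
    # Standard ALiBi schedule for a power-of-two number of heads:
    # slope_i = 2 ** (-8 * (i + 1) / nheads).
    # (The general case interpolates extra slopes; that branch is omitted here.)
    start = 2.0 ** (-8.0 / nheads)
    return torch.tensor([start ** (i + 1) for i in range(nheads)], dtype=torch.float32)

print(alibi_slopes(8))
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])

The diffs below do nothing more than look these slopes up on the attention module (when ALiBi is enabled) and forward them to the kv-cache kernel during generation.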
flash_attn/models/baichuan.py

-# Copyright (c) 2023, GGGGGGXY.
+# Copyright (c) 2023, GGGGGGXY, Tri Dao.
 import math
 import json
 ...
@@ -14,7 +14,6 @@ from einops import rearrange
 from transformers import GPT2Config, AutoConfig, PretrainedConfig


-# only support Baichuan-7B now
 def remap_state_dict_hf_baichuan(state_dict, config):
     def key_mapping_layers(key):
         return re.sub(r"^model.", "transformer.", key)
 ...
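Context for the hunk above: key_mapping_layers only rewrites the checkpoint key prefix. A standalone sketch of that remapping follows; the sample keys are illustrative rather than taken from a real Baichuan checkpoint.

import re

def key_mapping_layers(key):
    # Strip the HF-style "model." prefix so keys live under "transformer.".
    return re.sub(r"^model.", "transformer.", key)

# Sample keys, for illustration only.
hf_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.W_pack.weight",
    "lm_head.weight",
]
print([key_mapping_layers(k) for k in hf_keys])
# ['transformer.embed_tokens.weight',
#  'transformer.layers.0.self_attn.W_pack.weight',
#  'lm_head.weight']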
flash_attn/modules/mha.py

@@ -501,6 +501,7 @@ class MHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -513,6 +514,7 @@ class MHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
@@ -534,6 +536,7 @@ class MHA(nn.Module):
                 if inference_params.lengths_per_sample is not None
                 else inference_params.seqlen_offset
             )
+            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
             return flash_attn_with_kvcache(
                 q,
                 kv_cache[:, :, 0],
@@ -543,6 +546,7 @@ class MHA(nn.Module):
                 cache_seqlens=cache_seqlens,
                 softmax_scale=self.inner_cross_attn.softmax_scale,
                 causal=self.inner_cross_attn.causal,
+                alibi_slopes=alibi_slopes,
             )

     def forward(
@@ -847,6 +851,7 @@ class ParallelMHA(nn.Module):
             if inference_params.lengths_per_sample is not None
             else inference_params.seqlen_offset
         )
+        alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
         context = flash_attn_with_kvcache(
             q,
             kv_cache[:, :, 0],
@@ -859,6 +864,7 @@ class ParallelMHA(nn.Module):
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
             rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            alibi_slopes=alibi_slopes,
         )
         return context
@@ -876,6 +882,7 @@ class ParallelMHA(nn.Module):
                 if inference_params.lengths_per_sample is not None
                 else inference_params.seqlen_offset
             )
+            alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None)
             context = flash_attn_with_kvcache(
                 q,
                 kv_cache[:, :, 0],
@@ -885,6 +892,7 @@ class ParallelMHA(nn.Module):
                 cache_seqlens=cache_seqlens,
                 softmax_scale=self.inner_cross_attn.softmax_scale,
                 causal=self.inner_cross_attn.causal,
+                alibi_slopes=alibi_slopes,
             )
             return context
 ...
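To see what the patched decode path amounts to, here is a hedged sketch of a single generation step against a preallocated KV cache. It mirrors the calls above, but all shapes, values, and the slope tensor are illustrative; it assumes a CUDA device and a flash-attn version (>= 2.4) whose flash_attn_with_kvcache accepts alibi_slopes.

import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 4, 64, 128
device, dtype = "cuda", torch.float16

# Preallocated KV cache that already holds 16 tokens per sequence.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 16, dtype=torch.int32, device=device)

# Query/key/value for the single token generated at this step.
q = torch.randn(batch, 1, nheads, headdim, device=device, dtype=dtype)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

# In the patched code this comes from getattr(self.inner_cross_attn, "alibi_slopes", None);
# the values below are made up. flash-attn expects fp32 slopes of shape (nheads,).
alibi_slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625], device=device, dtype=torch.float32)

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=k_new,                     # written into the cache at position cache_seqlens
    v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    alibi_slopes=alibi_slopes,   # None reproduces the pre-commit behavior
)
print(out.shape)  # torch.Size([2, 1, 4, 64])

Because modules configured without ALiBi simply have no alibi_slopes attribute, the getattr fallback to None keeps their generation path unchanged.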
tests/models/test_baichuan.py

 # Copyright (c) 2023, Tri Dao.

 import os
 import time
 from pathlib import Path
 ...
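The single line added to this test file is not visible in the excerpt above. For orientation only, here is a hedged smoke check (not the actual test) that the slopes really reach the kernel: with non-zero slopes the kv-cache call should return a different result than with alibi_slopes left as None. It assumes a CUDA device and flash-attn >= 2.4.

import torch
from flash_attn import flash_attn_with_kvcache

torch.manual_seed(0)
batch, nheads, headdim, seqlen = 1, 4, 64, 32
q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
cache_seqlens = torch.full((batch,), seqlen, dtype=torch.int32, device="cuda")
slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625], device="cuda", dtype=torch.float32)

out_alibi = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True, alibi_slopes=slopes
)
out_plain = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True
)
# The linear position bias should change the attention output.
assert not torch.allclose(out_alibi, out_plain)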