gaoqiong / flash-attention

Commit 79681482, authored Mar 26, 2024 by skrider
Parent: a63157ea

allow smaller page sizes in varlen api
Showing 2 changed files with 2 additions and 2 deletions:

  csrc/flash_attn/flash_api.cpp   +1 / -1
  tests/test_flash_attn.py        +1 / -1
csrc/flash_attn/flash_api.cpp

@@ -561,7 +561,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
     const int page_block_size = !paged_KV ? 1 : k.size(1);
-    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
+    TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
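The sizes checked here are read straight from the paged tensors: num_blocks and page_block_size come from the key cache, and max_num_blocks_per_seq from the block table, so the relaxed TORCH_CHECK simply lets the paged KV cache use 16-entry pages instead of requiring multiples of 256. A minimal sketch of shapes that now pass the check, assuming the usual paged layout of (num_blocks, page_block_size, nheads_k, headdim) for the cache and (batch, max_num_blocks_per_seq) for the block table; the concrete dimension values are made up for illustration:

import torch

# Hypothetical sizes; only the divisibility of page_block_size matters for the check.
batch = 2
nheads_k, headdim = 4, 64
page_block_size = 16             # previously had to be a multiple of 256
max_num_blocks_per_seq = 8
num_blocks = batch * max_num_blocks_per_seq

# Assumed paged KV cache layout: (num_blocks, page_block_size, nheads_k, headdim)
k_cache = torch.empty(num_blocks, page_block_size, nheads_k, headdim, dtype=torch.float16)
# Assumed block table layout: (batch, max_num_blocks_per_seq), int32 indices into k_cache
block_table = torch.arange(num_blocks, dtype=torch.int32).reshape(batch, max_num_blocks_per_seq)

# Mirrors the new TORCH_CHECK: 16-entry pages are now accepted.
assert k_cache.size(1) % 16 == 0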
tests/test_flash_attn.py

@@ -1543,7 +1543,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
     ],
 )
 # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
-@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
 def test_flash_attn_varlen_causal(
     seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
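The added 16 entry in the parametrization drives a small-page case through the varlen path. For orientation, a rough sketch of what such a case looks like from the Python side, assuming flash_attn_varlen_func accepts a block_table keyword in this version; the tensor shapes, sizes, and block-table layout below are illustrative, not taken from the test:

import torch
from flash_attn import flash_attn_varlen_func

batch, seqlen, nheads, headdim = 2, 128, 4, 64
paged_kv_block_size = 16                       # the newly allowed page size
blocks_per_seq = seqlen // paged_kv_block_size
num_blocks = batch * blocks_per_seq

device, dtype = "cuda", torch.float16
q = torch.randn(batch * seqlen, nheads, headdim, device=device, dtype=dtype)
k_cache = torch.randn(num_blocks, paged_kv_block_size, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)

# Each sequence owns a contiguous run of pages in this toy block table.
block_table = torch.arange(num_blocks, device=device, dtype=torch.int32).reshape(batch, blocks_per_seq)
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, device=device, dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k_cache, v_cache,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seqlen, max_seqlen_k=seqlen,
    causal=True,
    block_table=block_table,
)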