Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
e524c2ca
Commit
e524c2ca
authored
Feb 11, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
allow small page sizes in flash api
parent
b1c18ca1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
4 additions
and
5 deletions
+4
-5
csrc/flash_attn/flash_api.cpp
csrc/flash_attn/flash_api.cpp
+1
-1
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+2
-3
tests/test_flash_attn.py
tests/test_flash_attn.py
+1
-1
No files found.
csrc/flash_attn/flash_api.cpp
View file @
e524c2ca
...
@@ -1285,7 +1285,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
...
@@ -1285,7 +1285,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const
int
max_num_blocks_per_seq
=
!
paged_KV
?
0
:
block_table
.
size
(
1
);
const
int
max_num_blocks_per_seq
=
!
paged_KV
?
0
:
block_table
.
size
(
1
);
const
int
num_blocks
=
!
paged_KV
?
0
:
kcache
.
size
(
0
);
const
int
num_blocks
=
!
paged_KV
?
0
:
kcache
.
size
(
0
);
const
int
page_block_size
=
!
paged_KV
?
1
:
kcache
.
size
(
1
);
const
int
page_block_size
=
!
paged_KV
?
1
:
kcache
.
size
(
1
);
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
const
int
seqlen_k
=
!
paged_KV
?
kcache
.
size
(
1
)
:
max_num_blocks_per_seq
*
page_block_size
;
const
int
seqlen_k
=
!
paged_KV
?
kcache
.
size
(
1
)
:
max_num_blocks_per_seq
*
page_block_size
;
const
int
num_heads_k
=
kcache
.
size
(
2
);
const
int
num_heads_k
=
kcache
.
size
(
2
);
const
int
batch_size_c
=
!
paged_KV
?
kcache
.
size
(
0
)
:
batch_size
;
const
int
batch_size_c
=
!
paged_KV
?
kcache
.
size
(
0
)
:
batch_size
;
...
...
csrc/flash_attn/src/utils.h
View file @
e524c2ca
...
@@ -328,12 +328,11 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
...
@@ -328,12 +328,11 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
constexpr
int
kGmemElemsPerLoad
=
Kernel_traits
::
kGmemElemsPerLoad
;
constexpr
int
kGmemElemsPerLoad
=
Kernel_traits
::
kGmemElemsPerLoad
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
// base row of thread's slice relative to the block
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
// base col of thread's slice relative to the entire tensor
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
kBlockN
;
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
kBlockN
;
// base row of thread's slice relative to the page
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
...
...
tests/test_flash_attn.py
View file @
e524c2ca
...
@@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv(
...
@@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv(
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
,
0.5
,
1.0
])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
0.0
,
0.5
,
1.0
])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@
pytest
.
mark
.
parametrize
(
"paged_kv_block_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"paged_kv_block_size"
,
[
16
,
256
,
512
])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False, True])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"has_batch_idx"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
32
,
59
,
64
,
80
,
128
,
256
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment