Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
f67a6edf
Commit
f67a6edf
authored
Feb 11, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
implement kv page iteration functions
parent
d4da6bfc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
71 additions
and
27 deletions
+71
-27
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+22
-27
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+49
-0
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
f67a6edf
...
@@ -621,16 +621,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -621,16 +621,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// We move K and V to the last block.
// We move K and V to the last block.
const
int
bidb_cache
=
params
.
cache_batch_idx
==
nullptr
?
bidb
:
params
.
cache_batch_idx
[
bidb
];
const
int
bidb_cache
=
params
.
cache_batch_idx
==
nullptr
?
bidb
:
params
.
cache_batch_idx
[
bidb
];
const
int
*
block_table
=
params
.
block_table
==
nullptr
?
nullptr
:
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
*
block_table
=
params
.
block_table
==
nullptr
?
nullptr
:
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
block_table_idx
=
block_table
==
nullptr
?
0
:
(
n_block_max
-
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset
=
block_table
==
nullptr
?
0
:
(
n_block_max
-
1
)
*
kBlockN
-
block_table_idx
*
params
.
page_block_size
;
const
index_t
row_offset_k
=
block_table
==
nullptr
const
index_t
row_offset_k
=
block_table
==
nullptr
?
binfo
.
k_offset
(
params
.
k_batch_stride
,
params
.
k_row_stride
,
bidb_cache
)
?
binfo
.
k_offset
(
params
.
k_batch_stride
,
params
.
k_row_stride
,
bidb_cache
)
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
:
block_table
[
block_table_idx
]
*
params
.
k_batch_stride
+
block_table_offset
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
:
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
bidh
/
params
.
h_h_k_ratio
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
,
params
.
k_head_stride
);
const
index_t
row_offset_v
=
block_table
==
nullptr
const
index_t
row_offset_v
=
block_table
==
nullptr
?
binfo
.
k_offset
(
params
.
v_batch_stride
,
params
.
v_row_stride
,
bidb_cache
)
?
binfo
.
k_offset
(
params
.
v_batch_stride
,
params
.
v_row_stride
,
bidb_cache
)
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
:
block_table
[
block_table_idx
]
*
params
.
v_batch_stride
+
block_table_offset
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
:
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
bidh
/
params
.
h_h_k_ratio
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
,
params
.
v_head_stride
);
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
...
@@ -842,14 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -842,14 +842,18 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
if
(
n_block
>
n_block_copy_min
)
{
if
(
n_block
>
n_block_copy_min
)
{
const
int
block_table_idx_cur
=
n_block
*
kBlockN
/
params
.
page_block_size
;
// const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
const
int
block_table_offset_cur
=
n_block
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
// const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
const
int
block_table_idx_next
=
(
n_block
-
1
)
*
kBlockN
/
params
.
page_block_size
;
// const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
// const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
const
int
table_diff
=
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
];
// const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
const
int
offset_diff
=
block_table_offset_next
-
block_table_offset_cur
;
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
tVgV
.
data
()
=
tVgV
.
data
()
+
table_diff
*
params
.
v_batch_stride
+
offset_diff
*
params
.
v_row_stride
;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
tKgK
.
data
()
=
tKgK
.
data
()
+
table_diff
*
params
.
k_batch_stride
+
offset_diff
*
params
.
k_row_stride
;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV
.
data
()
=
tVgV
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
tKgK
.
data
()
=
tKgK
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
}
}
}
}
...
@@ -973,11 +977,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -973,11 +977,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
const
int
block_table_idx_cur
=
n_block
*
kBlockN
/
params
.
page_block_size
;
tKgK
.
data
()
=
tKgK
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
const
int
block_table_offset_cur
=
n_block
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
params
.
k_batch_stride
,
params
.
k_row_stride
);
const
int
block_table_idx_next
=
(
n_block
-
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tKgK
.
data
()
=
tKgK
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
k_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
k_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
...
@@ -1016,11 +1017,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -1016,11 +1017,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
const
int
block_table_idx_cur
=
(
n_block
+
1
)
*
kBlockN
/
params
.
page_block_size
;
tVgV
.
data
()
=
tVgV
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
const
int
block_table_offset_cur
=
(
n_block
+
1
)
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
const
int
block_table_idx_next
=
n_block
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
...
@@ -1037,11 +1035,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -1037,11 +1035,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
const
int
block_table_idx_cur
=
n_block
*
kBlockN
/
params
.
page_block_size
;
tKgK
.
data
()
=
tKgK
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
const
int
block_table_offset_cur
=
n_block
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
const
int
block_table_idx_next
=
(
n_block
-
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tKgK
.
data
()
=
tKgK
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
k_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
k_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
...
...
csrc/flash_attn/src/utils.h
View file @
f67a6edf
...
@@ -292,6 +292,55 @@ void cp_async_wait() {
...
@@ -292,6 +292,55 @@ void cp_async_wait() {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Resolves the initial base element offset of this thread's slice of a paged KV
// copy from gmem, positioned at the last KV block (n_block_max - 1), translating
// the virtual row position through the page block table.
//
//   tidx            - thread index within the CTA
//   hidx            - head index used for the head-stride term (callers pass the
//                     KV head, e.g. bidh / h_h_k_ratio -- confirm at call site)
//   n_block_max     - number of kBlockN-row blocks; iteration starts at the last
//   page_block_size - rows per KV-cache page
//   block_table     - maps virtual page index -> physical page index
//   page_stride / row_stride / head_stride - element strides of the paged tensor
//
// Returns the element offset into the paged KV tensor for this thread's slice.
template <typename Kernel_traits>
__forceinline__ __device__
int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n_block_max,
                                     const int page_block_size, const int* block_table,
                                     const int page_stride, const int row_stride,
                                     const int head_stride) {
    // base col of thread's slice relative to the block
    const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad;
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
    // base row of thread's slice relative to the entire tensor
    // (original comment mislabeled this as a column offset)
    const int global_row_offset = block_row_offset + (n_block_max - 1) * Kernel_traits::kBlockN;
    // base row of thread's slice relative to the page it lands in
    const int page_offset = global_row_offset % page_block_size;
    // which entry of the block table that row falls into
    const int virtual_page_idx = global_row_offset / page_block_size;

    return block_table[virtual_page_idx] * page_stride
        + page_offset * row_stride
        + hidx * head_stride
        + col_offset;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Advances the base address of this thread's slice of a paged KV copy from gmem
// by one kBlockN block backwards (n_block -> n_block - 1), accounting for page
// crossings via the block table.
//
//   tidx            - thread index within the CTA
//   n_block         - current block index; the delta moves to n_block - 1
//   page_block_size - rows per KV-cache page
//   block_table     - maps virtual page index -> physical page index
//   page_stride / row_stride - element strides of the paged tensor
//
// Returns the (possibly negative) element-offset delta to add to the current
// thread-slice pointer.
template <typename Kernel_traits>
__forceinline__ __device__
int advance_thread_kv_page_slice_offset(const int tidx, const int n_block,
                                        const int page_block_size, const int* block_table,
                                        const int page_stride, const int row_stride) {
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
    // base rows of thread's slice relative to the entire tensor, before and after
    // the step (original comment mislabeled these as column offsets)
    const int global_row_offset_cur = block_row_offset + n_block * Kernel_traits::kBlockN;
    const int global_row_offset_next = block_row_offset + (n_block - 1) * Kernel_traits::kBlockN;
    // base rows of thread's slice relative to their containing pages
    const int page_offset_cur = global_row_offset_cur % page_block_size;
    const int page_offset_next = global_row_offset_next % page_block_size;
    // block-table entries for the current and next positions
    const int virtual_page_idx_cur = global_row_offset_cur / page_block_size;
    const int virtual_page_idx_next = global_row_offset_next / page_block_size;

    // delta in physical pages and delta in rows within the page
    const int table_diff = block_table[virtual_page_idx_next] - block_table[virtual_page_idx_cur];
    const int offset_diff = page_offset_next - page_offset_cur;

    return table_diff * page_stride + offset_diff * row_stride;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
bool
Is_even_MN
=
true
,
bool
Is_even_K
=
true
,
bool
Clear_OOB_MN
=
false
,
bool
Clear_OOB_K
=
true
,
template
<
bool
Is_even_MN
=
true
,
bool
Is_even_K
=
true
,
bool
Clear_OOB_MN
=
false
,
bool
Clear_OOB_K
=
true
,
typename
TiledCopy
,
typename
Engine0
,
typename
Layout0
,
typename
Engine1
,
typename
Layout1
,
typename
TiledCopy
,
typename
Engine0
,
typename
Layout0
,
typename
Engine1
,
typename
Layout1
,
typename
Engine2
,
typename
Layout2
,
typename
Engine3
,
typename
Layout3
>
typename
Engine2
,
typename
Layout2
,
typename
Engine3
,
typename
Layout3
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment