Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
a3e06cd5
Commit
a3e06cd5
authored
Feb 11, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
rearrange initial offset computation
parent
f67a6edf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
13 deletions
+20
-13
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+15
-9
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+5
-4
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
a3e06cd5
...
@@ -624,13 +624,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -624,13 +624,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
const
index_t
row_offset_k
=
block_table
==
nullptr
const
index_t
row_offset_k
=
block_table
==
nullptr
?
binfo
.
k_offset
(
params
.
k_batch_stride
,
params
.
k_row_stride
,
bidb_cache
)
?
binfo
.
k_offset
(
params
.
k_batch_stride
,
params
.
k_row_stride
,
bidb_cache
)
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
:
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
bidh
/
params
.
h_h_k_ratio
,
n_block_max
,
params
.
page_block_size
,
block_table
,
:
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
// block addresses are later resolved per-thread
params
.
k_batch_stride
,
params
.
k_row_stride
,
params
.
k_head_stride
);
const
index_t
row_offset_v
=
block_table
==
nullptr
const
index_t
row_offset_v
=
block_table
==
nullptr
?
binfo
.
k_offset
(
params
.
v_batch_stride
,
params
.
v_row_stride
,
bidb_cache
)
?
binfo
.
k_offset
(
params
.
v_batch_stride
,
params
.
v_row_stride
,
bidb_cache
)
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
+
(
n_block_max
-
1
)
*
kBlockN
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
:
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
bidh
/
params
.
h_h_k_ratio
,
n_block_max
,
params
.
page_block_size
,
block_table
,
:
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
params
.
v_batch_stride
,
params
.
v_row_stride
,
params
.
v_head_stride
);
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
...
@@ -667,6 +665,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -667,6 +665,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
tKsK
=
gmem_thr_copy_KV
.
partition_D
(
sK
);
Tensor
tKsK
=
gmem_thr_copy_KV
.
partition_D
(
sK
);
Tensor
tVgV
=
gmem_thr_copy_KV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVgV
=
gmem_thr_copy_KV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_KV
.
partition_D
(
sV
);
Tensor
tVsV
=
gmem_thr_copy_KV
.
partition_D
(
sV
);
if
(
block_table
!=
nullptr
)
{
tKgK
.
data
()
=
gV
.
data
()
+
flash
::
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
#if 1
#if 1
KIN_PRINT
(
print
(
tKgK
.
layout
()))
KIN_PRINT
(
print
(
tKgK
.
layout
()))
KIN_PRINT
(
print
(
tKsK
.
layout
()))
KIN_PRINT
(
print
(
tKsK
.
layout
()))
...
@@ -850,9 +856,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -850,9 +856,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV
.
data
()
=
tVgV
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
params
.
v_batch_stride
,
params
.
v_row_stride
);
tKgK
.
data
()
=
tKgK
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
}
}
...
@@ -977,7 +983,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -977,7 +983,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
tKgK
.
data
()
=
tKgK
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
...
@@ -1017,7 +1023,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -1017,7 +1023,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
tVgV
.
data
()
=
tVgV
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
@@ -1035,7 +1041,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -1035,7 +1041,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
tKgK
.
data
()
=
tKgK
.
data
()
+
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
...
...
csrc/flash_attn/src/utils.h
View file @
a3e06cd5
...
@@ -292,11 +292,12 @@ void cp_async_wait() {
...
@@ -292,11 +292,12 @@ void cp_async_wait() {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves initial base address of a slice of a paged kv copy from gmem
// resolves initial base offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template
<
typename
Kernel_traits
>
template
<
typename
Kernel_traits
>
__forceinline__
__device__
__forceinline__
__device__
int
init_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
hidx
,
const
int
n_block_max
,
const
int
page_block_size
,
int
init_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
,
const
int
head_stride
)
{
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
// base col of thread's slice relative to the block
// base col of thread's slice relative to the block
const
int
col_offset
=
tidx
%
Kernel_traits
::
kGmemThreadsPerRow
*
Kernel_traits
::
kGmemElemsPerLoad
;
const
int
col_offset
=
tidx
%
Kernel_traits
::
kGmemThreadsPerRow
*
Kernel_traits
::
kGmemElemsPerLoad
;
// base row of thread's slice relative to the block
// base row of thread's slice relative to the block
...
@@ -310,7 +311,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n
...
@@ -310,7 +311,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int hidx, const int n
return
block_table
[
virtual_page_idx
]
*
page_stride
return
block_table
[
virtual_page_idx
]
*
page_stride
+
page_offset
*
row_stride
+
page_offset
*
row_stride
+
hidx
*
head_stride
+
col_offset
;
+
col_offset
;
}
}
...
@@ -321,6 +321,7 @@ template <typename Kernel_traits>
...
@@ -321,6 +321,7 @@ template <typename Kernel_traits>
__forceinline__
__device__
__forceinline__
__device__
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
return
0
;
// base row of thread's slice relative to the block
// base row of thread's slice relative to the block
const
int
block_row_offset
=
tidx
/
Kernel_traits
::
kGmemThreadsPerRow
*
Kernel_traits
::
kGmemRowsPerThread
;
const
int
block_row_offset
=
tidx
/
Kernel_traits
::
kGmemThreadsPerRow
*
Kernel_traits
::
kGmemRowsPerThread
;
// base col of thread's slice relative to the entire tensor
// base col of thread's slice relative to the entire tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment