Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
36916777
Commit
36916777
authored
Feb 11, 2024
by
skrider
Browse files
paged copy refactor working for page size 256
parent
175369fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
26 deletions
+31
-26
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+23
-21
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+8
-5
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
36916777
...
@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
#if 1
#if 1
KIN_PRINT
([
&
]()
{
//
KIN_PRINT([&]() {
for
(
int
i
=
0
;
i
<
n_block_max
;
i
++
)
{
//
for (int i = 0; i < n_block_max; i++) {
printf
(
"%d "
,
block_table
[
i
]);
//
printf("%d ", block_table[i]);
}
//
}
}())
//
}())
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (thread0()) {
// if (thread0()) {
// gK.data() = tKgK.data();
// gK.data() = tKgK.data();
// }
// }
KIN_PRINT
(
print_tensor
(
tKgK
))
//
KIN_PRINT(print_tensor(tKgK))
KIN_PRINT
(
print_tensor
(
gK
))
//
KIN_PRINT(print_tensor(gK))
KIN_PRINT
(
print_tensor
(
tKgK__shadow
))
//
KIN_PRINT(print_tensor(tKgK__shadow))
KIN_PRINT
(
print_tensor
(
gK__shadow
))
//
KIN_PRINT(print_tensor(gK__shadow))
#endif
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
typename
Kernel_traits
::
TiledMma
tiled_mma
;
...
@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
}
}
}
}
...
@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
const
int
block_table_idx_cur
=
(
n_block
+
1
)
*
kBlockN
/
params
.
page_block_size
;
// const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const
int
block_table_offset_cur
=
(
n_block
+
1
)
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
// const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const
int
block_table_idx_next
=
n_block
*
kBlockN
/
params
.
page_block_size
;
// const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
// const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
// tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
}
else
{
}
else
{
...
@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
...
@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
...
csrc/flash_attn/src/utils.h
View file @
36916777
...
@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
...
@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
const
int
global_row_offset
=
block_row_offset
+
(
n_block_max
-
1
)
*
kBlockN
;
const
int
global_row_offset
=
block_row_offset
+
(
n_block_max
-
1
)
*
kBlockN
;
const
int
page_offset
=
global_row_offset
%
page_block_size
;
const
int
page_offset
=
global_row_offset
%
page_block_size
;
const
int
virtual_page_idx
=
global_row_offset
/
page_block_size
;
const
int
virtual_page_idx
=
global_row_offset
/
page_block_size
;
KIN_PRINT
(
printf
(
"%d"
,
virtual_page_idx
))
return
block_table
[
virtual_page_idx
]
*
page_stride
return
block_table
[
virtual_page_idx
]
*
page_stride
+
page_offset
*
row_stride
+
page_offset
*
row_stride
...
@@ -324,12 +323,16 @@ template <typename Kernel_traits>
...
@@ -324,12 +323,16 @@ template <typename Kernel_traits>
__forceinline__
__device__
__forceinline__
__device__
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
return
0
;
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kGmemElemsPerLoad
=
Kernel_traits
::
kGmemElemsPerLoad
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
// base row of thread's slice relative to the block
// base row of thread's slice relative to the block
const
int
block_row_offset
=
tidx
/
Kernel_traits
::
kGmemThreadsPerRow
*
Kernel_traits
::
kGmemRowsPerThread
;
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
// base col of thread's slice relative to the entire tensor
// base col of thread's slice relative to the entire tensor
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
Kernel_traits
::
kBlockN
;
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
Kernel_traits
::
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
kBlockN
;
// base row of thread's slice relative to the page
// base row of thread's slice relative to the page
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment