Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
b1c18ca1
"tests/models/autoencoders/test_models_vae.py" did not exist on "63f767ef15fa59704272ac7320ec23b8c15de246"
Commit
b1c18ca1
authored
Feb 11, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
paged copy refactor working for page size 256
parent
446204c7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
26 deletions
+31
-26
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+23
-21
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+8
-5
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
b1c18ca1
...
@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -620,19 +620,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
#if 1
#if 1
KIN_PRINT
([
&
]()
{
//
KIN_PRINT([&]() {
for
(
int
i
=
0
;
i
<
n_block_max
;
i
++
)
{
//
for (int i = 0; i < n_block_max; i++) {
printf
(
"%d "
,
block_table
[
i
]);
//
printf("%d ", block_table[i]);
}
//
}
}())
//
}())
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (thread0()) {
// if (thread0()) {
// gK.data() = tKgK.data();
// gK.data() = tKgK.data();
// }
// }
KIN_PRINT
(
print_tensor
(
tKgK
))
//
KIN_PRINT(print_tensor(tKgK))
KIN_PRINT
(
print_tensor
(
gK
))
//
KIN_PRINT(print_tensor(gK))
KIN_PRINT
(
print_tensor
(
tKgK__shadow
))
//
KIN_PRINT(print_tensor(tKgK__shadow))
KIN_PRINT
(
print_tensor
(
gK__shadow
))
//
KIN_PRINT(print_tensor(gK__shadow))
#endif
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
typename
Kernel_traits
::
TiledMma
tiled_mma
;
...
@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -783,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// const int offset_diff = block_table_offset_next - block_table_offset_cur;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
// tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
}
}
}
}
...
@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -875,11 +875,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
const
int
block_table_idx_cur
=
(
n_block
+
1
)
*
kBlockN
/
params
.
page_block_size
;
// const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
const
int
block_table_offset_cur
=
(
n_block
+
1
)
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
// const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
const
int
block_table_idx_next
=
n_block
*
kBlockN
/
params
.
page_block_size
;
// const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
// const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
// tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
}
else
{
}
else
{
...
@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -910,8 +912,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
tKgK
.
data
()
=
tKgK
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
...
@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -950,7 +952,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tVgV
.
data
()
=
tVgV
.
data
()
+
flash
::
advance_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
...
csrc/flash_attn/src/utils.h
View file @
b1c18ca1
...
@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
...
@@ -310,7 +310,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
const
int
global_row_offset
=
block_row_offset
+
(
n_block_max
-
1
)
*
kBlockN
;
const
int
global_row_offset
=
block_row_offset
+
(
n_block_max
-
1
)
*
kBlockN
;
const
int
page_offset
=
global_row_offset
%
page_block_size
;
const
int
page_offset
=
global_row_offset
%
page_block_size
;
const
int
virtual_page_idx
=
global_row_offset
/
page_block_size
;
const
int
virtual_page_idx
=
global_row_offset
/
page_block_size
;
KIN_PRINT
(
printf
(
"%d"
,
virtual_page_idx
))
return
block_table
[
virtual_page_idx
]
*
page_stride
return
block_table
[
virtual_page_idx
]
*
page_stride
+
page_offset
*
row_stride
+
page_offset
*
row_stride
...
@@ -324,12 +323,16 @@ template <typename Kernel_traits>
...
@@ -324,12 +323,16 @@ template <typename Kernel_traits>
__forceinline__
__device__
__forceinline__
__device__
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
return
0
;
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kGmemElemsPerLoad
=
Kernel_traits
::
kGmemElemsPerLoad
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
// base row of thread's slice relative to the block
// base row of thread's slice relative to the block
const
int
block_row_offset
=
tidx
/
Kernel_traits
::
kGmemThreadsPerRow
*
Kernel_traits
::
kGmemRowsPerThread
;
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
// base col of thread's slice relative to the entire tensor
// base col of thread's slice relative to the entire tensor
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
Kernel_traits
::
kBlockN
;
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
Kernel_traits
::
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
kBlockN
;
// base row of thread's slice relative to the page
// base row of thread's slice relative to the page
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment