Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
58611cd9
Commit
58611cd9
authored
Mar 26, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
resolve page offsets absolutely not relatively
parent
10b6f3a8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
38 deletions
+10
-38
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+8
-8
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+2
-30
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
58611cd9
...
@@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
tVsV
=
make_tensor
(
tVsV_
.
data
(),
reshape_thread_tile
(
tVsV_
.
layout
()));
Tensor
tVsV
=
make_tensor
(
tVsV_
.
data
(),
reshape_thread_tile
(
tVsV_
.
layout
()));
if
(
block_table
!=
nullptr
)
{
if
(
block_table
!=
nullptr
)
{
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
init
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolve
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
init
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolve
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
...
@@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
if
(
n_block
>
n_block_copy_min
)
{
if
(
n_block
>
n_block_copy_min
)
{
tVgV
.
data
()
=
tV
gV
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
tKgK
.
data
()
=
tK
gK
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
}
}
...
@@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
tVgV
.
data
()
=
tV
gV
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
@@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
tKgK
.
data
()
=
tK
gK
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
...
@@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
}
else
{
tVgV
.
data
()
=
tV
gV
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
@@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
...
@@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
}
else
{
tKgK
.
data
()
=
tK
gK
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
...
...
csrc/flash_attn/src/utils.h
View file @
58611cd9
...
@@ -292,11 +292,11 @@ void cp_async_wait() {
...
@@ -292,11 +292,11 @@ void cp_async_wait() {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves
initial base
offset of a slice of a paged kv copy from gmem.
// resolves offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
// assumes that the tensor has already been positioned at the correct head.
template
<
typename
Kernel_traits
>
template
<
typename
Kernel_traits
>
__forceinline__
__device__
__forceinline__
__device__
int
init
_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
int
resolve
_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
...
@@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
...
@@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
+
page_offset
*
row_stride
+
page_offset
*
row_stride
+
col_offset
;
+
col_offset
;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// advances base address of a slice of a paged copy from gmem
template
<
typename
Kernel_traits
>
__forceinline__
__device__
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
kBlockN
;
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
const
int
virtual_page_idx_cur
=
global_row_offset_cur
/
page_block_size
;
const
int
virtual_page_idx_next
=
global_row_offset_next
/
page_block_size
;
const
int
table_diff
=
block_table
[
virtual_page_idx_next
]
-
block_table
[
virtual_page_idx_cur
];
const
int
offset_diff
=
page_offset_next
-
page_offset_cur
;
return
table_diff
*
page_stride
+
offset_diff
*
row_stride
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment