Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
58611cd9
"vscode:/vscode.git/clone" did not exist on "4e49542a3683c657d5bd9334f1e6f1ba59a426fa"
Commit
58611cd9
authored
Mar 26, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
resolve page offsets absolutely not relatively
parent
10b6f3a8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
38 deletions
+10
-38
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+8
-8
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+2
-30
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
58611cd9
...
...
@@ -609,9 +609,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
tVsV
=
make_tensor
(
tVsV_
.
data
(),
reshape_thread_tile
(
tVsV_
.
layout
()));
if
(
block_table
!=
nullptr
)
{
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
init
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolve
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
init
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolve
_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
...
...
@@ -769,9 +769,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
if
(
n_block
>
n_block_copy_min
)
{
tVgV
.
data
()
=
tV
gV
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
tKgK
.
data
()
=
tK
gK
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
}
...
...
@@ -865,7 +865,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
tVgV
.
data
()
=
tV
gV
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
...
@@ -897,7 +897,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
tKgK
.
data
()
=
tK
gK
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
...
...
@@ -937,7 +937,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
tVgV
.
data
()
=
tVgV
.
data
()
+
(
-
int
(
kBlockN
*
params
.
v_row_stride
));
}
else
{
tVgV
.
data
()
=
tV
gV
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
+
1
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
...
...
@@ -955,7 +955,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
if
(
block_table
==
nullptr
)
{
tKgK
.
data
()
=
tKgK
.
data
()
+
(
-
int
(
kBlockN
*
params
.
k_row_stride
));
}
else
{
tKgK
.
data
()
=
tK
gK
.
data
()
+
flash
::
advanc
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
tKgK
.
data
()
=
gK
.
data
()
+
flash
::
resolv
e_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block
,
params
.
page_block_size
,
block_table
,
params
.
k_batch_stride
,
params
.
k_row_stride
);
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
...
...
csrc/flash_attn/src/utils.h
View file @
58611cd9
...
...
@@ -292,11 +292,11 @@ void cp_async_wait() {
////////////////////////////////////////////////////////////////////////////////////////////////////
// resolves
initial base
offset of a slice of a paged kv copy from gmem.
// resolves offset of a slice of a paged kv copy from gmem.
// assumes that the tensor has already been positioned at the correct head.
template
<
typename
Kernel_traits
>
__forceinline__
__device__
int
init
_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
int
resolve
_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
...
...
@@ -313,34 +313,6 @@ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, cons
+
page_offset
*
row_stride
+
col_offset
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// advances base address of a slice of a paged copy from gmem
template
<
typename
Kernel_traits
>
__forceinline__
__device__
int
advance_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
const
int
global_row_offset_cur
=
block_row_offset
+
n_block
*
kBlockN
;
const
int
global_row_offset_next
=
block_row_offset
+
(
n_block
-
1
)
*
kBlockN
;
const
int
page_offset_cur
=
global_row_offset_cur
%
page_block_size
;
const
int
page_offset_next
=
global_row_offset_next
%
page_block_size
;
const
int
virtual_page_idx_cur
=
global_row_offset_cur
/
page_block_size
;
const
int
virtual_page_idx_next
=
global_row_offset_next
/
page_block_size
;
const
int
table_diff
=
block_table
[
virtual_page_idx_next
]
-
block_table
[
virtual_page_idx_cur
];
const
int
offset_diff
=
page_offset_next
-
page_offset_cur
;
return
table_diff
*
page_stride
+
offset_diff
*
row_stride
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment