gaoqiong / flash-attention · Commits

Commit c29f7313
Authored Feb 11, 2024 by skrider
Committed by Woosuk Kwon on Mar 28, 2024

tidy flash_fwd_kernel

Parent: 3f2484ee

Showing 1 changed file with 0 additions and 18 deletions.
csrc/flash_attn/src/flash_fwd_kernel.h  +0 -18  (view file @ c29f7313)
@@ -566,7 +566,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
           + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
         : (bidh / params.h_h_k_ratio) * params.k_head_stride;  // block addresses are later resolved per-thread
-    const index_t row_offset_k__shadow = block_table[(n_block_max - 1) * kBlockN / params.page_block_size] * params.k_batch_stride + (((n_block_max - 1) * kBlockN) % params.page_block_size) * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
     const index_t row_offset_v = block_table == nullptr
         ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
           + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
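The single line removed in this hunk duplicated, for debugging, the paged-KV starting offset of the last K tile. The sketch below (plain C++, made-up values, hypothetical helper name, not part of the diff) restates that arithmetic outside the kernel: the tile's first token index is split into a page index looked up in block_table plus a row offset within that page, and the query head bidh is mapped to its KV head via h_h_k_ratio.

    // Host-side sketch of the arithmetic in the deleted row_offset_k__shadow line.
    // Hypothetical helper for illustration only; not part of the kernel.
    #include <cstdint>
    #include <cstdio>

    using index_t = int64_t;

    index_t paged_row_offset_k(const int* block_table, int n_block_max, int kBlockN,
                               int page_block_size, index_t k_batch_stride,
                               index_t k_row_stride, index_t k_head_stride,
                               int bidh, int h_h_k_ratio) {
        const int token   = (n_block_max - 1) * kBlockN;    // first row of the last K tile
        const int page    = token / page_block_size;        // which cache page holds it
        const int in_page = token % page_block_size;        // row within that page
        return block_table[page] * k_batch_stride            // jump to the page
             + index_t(in_page) * k_row_stride               // then to the row
             + index_t(bidh / h_h_k_ratio) * k_head_stride;  // then to the KV head
    }

    int main() {
        const int block_table[] = {7, 3, 12, 5};  // made-up page indices
        // Made-up layout: 16-token pages, 8 KV heads, head_dim 128.
        const long long off = paged_row_offset_k(block_table, /*n_block_max=*/4, /*kBlockN=*/16,
                                                 /*page_block_size=*/16,
                                                 /*k_batch_stride=*/16 * 8 * 128,
                                                 /*k_row_stride=*/8 * 128,
                                                 /*k_head_stride=*/128,
                                                 /*bidh=*/5, /*h_h_k_ratio=*/4);
        printf("row_offset_k = %lld\n", off);
        return 0;
    }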
@@ -580,9 +579,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.k_row_stride, _1{}));
-    Tensor gK__shadow = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k__shadow),
-                                    Shape<Int<kBlockN>, Int<kHeadDim>>{},
-                                    make_stride(params.k_row_stride, _1{}));
     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
     Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
@@ -602,7 +598,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
     Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);  // (KCPY, KCPY_N, KCPY_K)
-    Tensor tKgK__shadow = gmem_thr_copy_KV.partition_S(gK__shadow);  // (KCPY, KCPY_N, KCPY_K)
     Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
@@ -754,14 +749,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
         } else {
             if (n_block > n_block_copy_min) {
-                // const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
-                // const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
-                // const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
-                // const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
-                // const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
-                // const int offset_diff = block_table_offset_next - block_table_offset_cur;
-                // tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
-                // tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
                 tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride);
                 tKgK.data() = tKgK.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block, params.page_block_size,
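The commented-out lines removed above documented the per-block pointer delta that the flash::advance_thread_kv_page_slice_offset call sites now compute per thread. As a reading aid, here is a small host-side sketch of that delta under made-up strides; the helper name is a stand-in and this is not the actual device implementation. Stepping the K/V tile from block n_block to n_block - 1 combines a block-table (page) difference scaled by the batch stride with an in-page row difference scaled by the row stride.

    // Sketch of the block-to-block KV pointer delta spelled out in the deleted comments.
    // Illustrative stand-in only; not flash::advance_thread_kv_page_slice_offset itself.
    #include <cstdint>
    #include <cstdio>

    using index_t = int64_t;

    index_t kv_page_slice_delta(const int* block_table, int n_block, int kBlockN,
                                int page_block_size, index_t batch_stride, index_t row_stride) {
        const int idx_cur     = n_block * kBlockN / page_block_size;                   // page of the current tile
        const int offset_cur  = n_block * kBlockN - idx_cur * page_block_size;         // row within that page
        const int idx_next    = (n_block - 1) * kBlockN / page_block_size;             // page of the previous tile
        const int offset_next = (n_block - 1) * kBlockN - idx_next * page_block_size;  // row within that page
        const int table_diff  = block_table[idx_next] - block_table[idx_cur];
        const int offset_diff = offset_next - offset_cur;
        return index_t(table_diff) * batch_stride + index_t(offset_diff) * row_stride;
    }

    int main() {
        const int block_table[] = {9, 2, 6, 11};  // made-up page indices
        // Made-up layout: 32-token pages, 16-row K/V tiles, row stride 1024 elements.
        const long long delta = kv_page_slice_delta(block_table, /*n_block=*/2, /*kBlockN=*/16,
                                                    /*page_block_size=*/32,
                                                    /*batch_stride=*/32 * 1024, /*row_stride=*/1024);
        printf("delta = %lld\n", delta);
        return 0;
    }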
@@ -854,11 +841,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         if (block_table == nullptr) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
         } else {
-            // const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
-            // const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
-            // const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
-            // const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
-            // tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
             tVgV.data() = tVgV.data() + flash::advance_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block + 1, params.page_block_size, block_table, params.v_batch_stride, params.v_row_stride);
         }