Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
582eb8fc
Unverified
Commit
582eb8fc
authored
Sep 03, 2024
by
Jongseok Park
Committed by
GitHub
Sep 03, 2024
Browse files
Fix params.seqlen_k reference in the splitkv kernel to binfo.actual_seqlen_k (#18)
parent
f9d2c100
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
1 deletion
+1
-1
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+1
-1
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
582eb8fc
...
...
@@ -531,7 +531,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
const
int
n_blocks_per_split
=
((
params
.
seqlen_k
+
kBlockN
-
1
)
/
kBlockN
+
num_n_splits
-
1
)
/
num_n_splits
;
const
int
n_blocks_per_split
=
((
binfo
.
actual_
seqlen_k
+
kBlockN
-
1
)
/
kBlockN
+
num_n_splits
-
1
)
/
num_n_splits
;
const
int
n_block_min
=
!
Is_local
?
n_split_idx
*
n_blocks_per_split
:
std
::
max
(
n_split_idx
*
n_blocks_per_split
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment