Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
01a27728
Commit
01a27728
authored
Apr 28, 2025
by
ljss
Browse files
Fix synchronization issues
parent
70b94685
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
4 deletions
+10
-4
.gitignore
.gitignore
+1
-0
csrc/kernels/splitkv_mla.cu
csrc/kernels/splitkv_mla.cu
+7
-3
csrc/kernels/traits.h
csrc/kernels/traits.h
+2
-1
No files found.
.gitignore
View file @
01a27728
...
...
@@ -6,3 +6,4 @@ dist/
*perf.csv
*.png
/.vscode
compile_commands.json
csrc/kernels/splitkv_mla.cu
View file @
01a27728
...
...
@@ -1017,13 +1017,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
cudaGridDependencySynchronize
();
int
*
tile_scheduler_metadata_ptr
=
params
.
tile_scheduler_metadata_ptr
+
partition_idx
*
TileSchedulerMetaDataSize
;
int4
tile_scheduler_metadata
=
__ldg
(
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata_ptr
));
// We don't use __ldg here; otherwise NVCC (ptxas, in particular) will reorder instructions and place the __ldg (LDG.E.128.CONSTANT in SASS) ahead of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to a data race.
int4
tile_scheduler_metadata
=
*
(
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata_ptr
));
int
begin_idx
=
tile_scheduler_metadata
.
x
;
int
begin_seqlen
=
tile_scheduler_metadata
.
y
;
int
end_idx
=
tile_scheduler_metadata
.
z
;
int
end_seqlen
=
tile_scheduler_metadata
.
w
;
if
(
begin_idx
>=
params
.
b
)
return
;
int
begin_n_split_idx
=
__ldg
(
tile_scheduler_metadata_ptr
+
4
);
int
begin_n_split_idx
=
*
(
tile_scheduler_metadata_ptr
+
4
);
// Copy the first Q
launch_q_copy
<
T
>
(
tma_params
,
begin_idx
,
m_block_idx
,
k_head_idx
,
sQ
,
barrier_Q
);
...
...
@@ -1123,6 +1124,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
// Issue P0 = Q @ K0^T, wait
warpgroup_cooperative_qkt_gemm_no_pipeline
<
T
>
(
sQ
,
sK0
,
rP0
,
idx_in_warpgroup
);
// We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0
NamedBarrier
::
arrive_and_wait
(
128
,
NamedBarriers
::
sMInitialized
);
cute
::
warpgroup_wait
<
0
>
();
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
...
...
@@ -1238,7 +1241,8 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
cute
::
tma_store_wait
<
0
>
();
}
else
{
int
split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
)
+
n_split_idx
;
// Don't use __ldg here: with PDL (programmatic dependent launch), instruction reordering may hoist the load ahead of cudaGridDependencySynchronize()
int
split_idx
=
params
.
num_splits_ptr
[
batch_idx
]
+
n_split_idx
;
float
*
oaccum_ptr
=
(
float
*
)
params
.
oaccum_ptr
+
((
split_idx
*
params
.
h_k
+
k_head_idx
)
*
params
.
q_seq_per_hk
+
m_block_idx
*
T
::
BLOCK_SIZE_M
)
*
T
::
HEAD_DIM_V
;
// (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
float
*
softmax_lseaccum_ptr
=
(
float
*
)
params
.
softmax_lseaccum_ptr
+
(
split_idx
*
params
.
h_k
+
k_head_idx
)
*
params
.
q_seq_per_hk
+
m_block_idx
*
T
::
BLOCK_SIZE_M
;
// (BLOCK_SIZE_M) : (1)
Tensor
gOAccum
=
make_tensor
(
make_gmem_ptr
(
oaccum_ptr
),
Layout
<
...
...
csrc/kernels/traits.h
View file @
01a27728
...
...
@@ -102,5 +102,6 @@ enum NamedBarriers : int {
sScale0Ready
=
0
,
sScale1Ready
=
1
,
sP0Ready
=
2
,
rO1sP0sV0RIssued
=
3
rO1sP0sV0RIssued
=
3
,
sMInitialized
=
4
,
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment