Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
3969f20b
"wrappers/vscode:/vscode.git/clone" did not exist on "cebb9934fb5dcbb0524d9f12d5f0d1200ceffe7f"
Commit
3969f20b
authored
Sep 29, 2025
by
Shengyu Liu
Browse files
Merge remote-tracking branch 'github/main' into open-source-h
parents
7232d69d
ebf30641
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
17 additions
and
11 deletions
+17
-11
csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
...ollective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
+4
-1
csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
...ollective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
+0
-4
csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp
.../dense/collective/sm100_fmha_load_tma_warpspecialized.hpp
+3
-0
csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
...ctive/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
+0
-4
csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
...se/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
+3
-0
csrc/sm100/prefill/dense/device/fmha.hpp
csrc/sm100/prefill/dense/device/fmha.hpp
+5
-0
csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp
...sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp
+1
-1
csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp
csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp
+1
-1
No files found.
csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
View file @
3969f20b
...
...
@@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
get
<
0
>
(
problem_shape_O
).
max_length
=
max
(
1
,
max_length_q
);
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dO
)
=
get
<
0
>
(
dO
);
get
<
2
,
1
>
(
problem_shape_O
)
=
max_length_q
*
(
1
+
get
<
2
,
1
>
(
problem_shape_O
));
get
<
2
,
1
>
(
problem_shape_O
)
=
max
(
1
,
max_length_q
*
(
1
+
get
<
2
,
1
>
(
problem_shape_O
))
)
;
// offset ptr by the amount we add back in later
ptr_O
-=
max_length_q
*
get
<
0
>
(
dO
);
}
}
else
{
get
<
0
>
(
problem_shape_O
)
=
max
(
1
,
get
<
0
>
(
problem_shape_O
));
}
auto
tma_store_o
=
make_tma_copy
(
...
...
csrc/sm100/prefill/dense/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
View file @
3969f20b
...
...
@@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float
lse
=
-
INFINITY
;
int
thread_idx
=
threadIdx
.
x
%
(
4
*
NumThreadsPerWarp
);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if
(
threadIdx
.
x
%
128
==
0
&&
block0
())
{
DSHOW
(
sO
);
}
#if 1
using
ElementOut
=
typename
CollectiveEpilogue
::
ElementOut
;
...
...
csrc/sm100/prefill/dense/collective/sm100_fmha_load_tma_warpspecialized.hpp
View file @
3969f20b
...
...
@@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized {
problem_shape_qk
=
problem_shape
;
}
get
<
0
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
0
>
(
problem_shape_qk
));
get
<
1
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
1
>
(
problem_shape_qk
));
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
problem_shape_qk
,
typename
CollectiveMmaQK
::
Arguments
{
...
...
csrc/sm100/prefill/dense/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
View file @
3969f20b
...
...
@@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized {
float
lse
=
-
INFINITY
;
int
thread_idx
=
threadIdx
.
x
%
(
4
*
NumThreadsPerWarp
);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if
(
threadIdx
.
x
%
128
==
0
&&
block0
())
{
DSHOW
(
sO
);
}
#if 1
using
ElementOut
=
typename
CollectiveEpilogue
::
ElementOut
;
...
...
csrc/sm100/prefill/dense/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
View file @
3969f20b
...
...
@@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));;
}
get
<
0
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
0
>
(
problem_shape_qk
));
get
<
1
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
1
>
(
problem_shape_qk
));
auto
problem_shape_pv
=
replace
<
1
>
(
select
<
0
,
2
,
1
,
3
>
(
problem_shape_qk
),
get
<
2
,
0
>
(
problem_shape
));
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
...
...
csrc/sm100/prefill/dense/device/fmha.hpp
View file @
3969f20b
...
...
@@ -208,6 +208,11 @@ public:
dim3
const
block
=
Kernel
::
get_block_shape
();
dim3
const
grid
=
get_grid_shape
(
params
);
// No need to launch the kernel
if
(
grid
.
x
==
0
||
grid
.
y
==
0
||
grid
.
z
==
0
)
{
return
Status
::
kSuccess
;
}
// configure smem size and carveout
int
smem_size
=
Kernel
::
SharedStorageSize
;
...
...
csrc/sm100/prefill/dense/kernel/fmha_causal_tile_scheduler.hpp
View file @
3969f20b
...
...
@@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler {
return
Params
{
num_blocks
,
{
size
<
3
,
0
>
(
problem_size
)
},
{
num_m_blocks
},
{
size
<
3
,
1
>
(
problem_size
)
},
{
size
<
3
,
0
>
(
problem_size
)
},
{
max
(
1
,
num_m_blocks
)
},
{
size
<
3
,
1
>
(
problem_size
)
},
hw_info
};
}
...
...
csrc/sm100/prefill/dense/kernel/fmha_tile_scheduler.hpp
View file @
3969f20b
...
...
@@ -123,7 +123,7 @@ struct PersistentTileScheduler {
return
Params
{
num_blocks
,
{
num_m_blocks
},
{
size
<
3
,
0
>
(
problem_size
)
},
{
size
<
3
,
1
>
(
problem_size
)
},
{
max
(
1
,
num_m_blocks
)
},
{
size
<
3
,
0
>
(
problem_size
)
},
{
size
<
3
,
1
>
(
problem_size
)
},
hw_info
};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment