Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
ebf30641
Unverified
Commit
ebf30641
authored
Sep 22, 2025
by
zhang
Committed by
GitHub
Sep 22, 2025
Browse files
Refine handling for q/v sequence length equals zero. (#92)
parent
261330bb
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
20 additions
and
11 deletions
+20
-11
csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
...ollective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
+4
-1
csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
...ollective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
+0
-4
csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
.../sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
+3
-0
csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
...ctive/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
+0
-4
csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
...00/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
+3
-0
csrc/sm100/device/fmha.hpp
csrc/sm100/device/fmha.hpp
+5
-0
csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp
csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp
+1
-1
csrc/sm100/kernel/fmha_tile_scheduler.hpp
csrc/sm100/kernel/fmha_tile_scheduler.hpp
+1
-1
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+3
-0
No files found.
csrc/sm100/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp
View file @
ebf30641
...
@@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
...
@@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
get
<
0
>
(
problem_shape_O
).
max_length
=
max
(
1
,
max_length_q
);
// for variable sequence lenght, the batch is in units of row_stride
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dO
)
=
get
<
0
>
(
dO
);
get
<
2
,
1
>
(
dO
)
=
get
<
0
>
(
dO
);
get
<
2
,
1
>
(
problem_shape_O
)
=
max_length_q
*
(
1
+
get
<
2
,
1
>
(
problem_shape_O
));
get
<
2
,
1
>
(
problem_shape_O
)
=
max
(
1
,
max_length_q
*
(
1
+
get
<
2
,
1
>
(
problem_shape_O
))
)
;
// offset ptr by the amount we add back in later
// offset ptr by the amount we add back in later
ptr_O
-=
max_length_q
*
get
<
0
>
(
dO
);
ptr_O
-=
max_length_q
*
get
<
0
>
(
dO
);
}
}
}
else
{
get
<
0
>
(
problem_shape_O
)
=
max
(
1
,
get
<
0
>
(
problem_shape_O
));
}
}
auto
tma_store_o
=
make_tma_copy
(
auto
tma_store_o
=
make_tma_copy
(
...
...
csrc/sm100/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
View file @
ebf30641
...
@@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
...
@@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float
lse
=
-
INFINITY
;
float
lse
=
-
INFINITY
;
int
thread_idx
=
threadIdx
.
x
%
(
4
*
NumThreadsPerWarp
);
int
thread_idx
=
threadIdx
.
x
%
(
4
*
NumThreadsPerWarp
);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if
(
threadIdx
.
x
%
128
==
0
&&
block0
())
{
DSHOW
(
sO
);
}
#if 1
#if 1
using
ElementOut
=
typename
CollectiveEpilogue
::
ElementOut
;
using
ElementOut
=
typename
CollectiveEpilogue
::
ElementOut
;
...
...
csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
View file @
ebf30641
...
@@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized {
...
@@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized {
problem_shape_qk
=
problem_shape
;
problem_shape_qk
=
problem_shape
;
}
}
get
<
0
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
0
>
(
problem_shape_qk
));
get
<
1
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
1
>
(
problem_shape_qk
));
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
problem_shape_qk
,
problem_shape_qk
,
typename
CollectiveMmaQK
::
Arguments
{
typename
CollectiveMmaQK
::
Arguments
{
...
...
csrc/sm100/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp
View file @
ebf30641
...
@@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized {
...
@@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized {
float
lse
=
-
INFINITY
;
float
lse
=
-
INFINITY
;
int
thread_idx
=
threadIdx
.
x
%
(
4
*
NumThreadsPerWarp
);
int
thread_idx
=
threadIdx
.
x
%
(
4
*
NumThreadsPerWarp
);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if
(
threadIdx
.
x
%
128
==
0
&&
block0
())
{
DSHOW
(
sO
);
}
#if 1
#if 1
using
ElementOut
=
typename
CollectiveEpilogue
::
ElementOut
;
using
ElementOut
=
typename
CollectiveEpilogue
::
ElementOut
;
...
...
csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
View file @
ebf30641
...
@@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
...
@@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));;
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));;
}
}
get
<
0
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
0
>
(
problem_shape_qk
));
get
<
1
>
(
problem_shape_qk
)
=
max
(
1
,
get
<
1
>
(
problem_shape_qk
));
auto
problem_shape_pv
=
replace
<
1
>
(
select
<
0
,
2
,
1
,
3
>
(
problem_shape_qk
),
get
<
2
,
0
>
(
problem_shape
));
auto
problem_shape_pv
=
replace
<
1
>
(
select
<
0
,
2
,
1
,
3
>
(
problem_shape_qk
),
get
<
2
,
0
>
(
problem_shape
));
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
...
...
csrc/sm100/device/fmha.hpp
View file @
ebf30641
...
@@ -208,6 +208,11 @@ public:
...
@@ -208,6 +208,11 @@ public:
dim3
const
block
=
Kernel
::
get_block_shape
();
dim3
const
block
=
Kernel
::
get_block_shape
();
dim3
const
grid
=
get_grid_shape
(
params
);
dim3
const
grid
=
get_grid_shape
(
params
);
// No need to launch the kernel
if
(
grid
.
x
==
0
||
grid
.
y
==
0
||
grid
.
z
==
0
)
{
return
Status
::
kSuccess
;
}
// configure smem size and carveout
// configure smem size and carveout
int
smem_size
=
Kernel
::
SharedStorageSize
;
int
smem_size
=
Kernel
::
SharedStorageSize
;
...
...
csrc/sm100/kernel/fmha_causal_tile_scheduler.hpp
View file @
ebf30641
...
@@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler {
...
@@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler {
return
Params
{
return
Params
{
num_blocks
,
num_blocks
,
{
size
<
3
,
0
>
(
problem_size
)
},
{
num_m_blocks
},
{
size
<
3
,
1
>
(
problem_size
)
},
{
size
<
3
,
0
>
(
problem_size
)
},
{
max
(
1
,
num_m_blocks
)
},
{
size
<
3
,
1
>
(
problem_size
)
},
hw_info
hw_info
};
};
}
}
...
...
csrc/sm100/kernel/fmha_tile_scheduler.hpp
View file @
ebf30641
...
@@ -123,7 +123,7 @@ struct PersistentTileScheduler {
...
@@ -123,7 +123,7 @@ struct PersistentTileScheduler {
return
Params
{
return
Params
{
num_blocks
,
num_blocks
,
{
num_m_blocks
},
{
size
<
3
,
0
>
(
problem_size
)
},
{
size
<
3
,
1
>
(
problem_size
)
},
{
max
(
1
,
num_m_blocks
)
},
{
size
<
3
,
0
>
(
problem_size
)
},
{
size
<
3
,
1
>
(
problem_size
)
},
hw_info
hw_info
};
};
}
}
...
...
tests/test_fmha_sm100.py
View file @
ebf30641
...
@@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window):
...
@@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window):
def
assert_close
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
def
assert_close
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
close_tensor
=
torch
.
isclose
(
x
.
to
(
torch
.
float32
),
y
.
to
(
torch
.
float32
),
rtol
=
1e-5
,
atol
=
1e-5
)
if
close_tensor
.
all
():
return
x
,
y
=
x
.
double
(),
y
.
double
()
x
,
y
=
x
.
double
(),
y
.
double
()
RMSE
=
((
x
-
y
)
*
(
x
-
y
)).
mean
().
sqrt
().
item
()
RMSE
=
((
x
-
y
)
*
(
x
-
y
)).
mean
().
sqrt
().
item
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
((
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
((
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment