Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2d291b0c
Unverified
Commit
2d291b0c
authored
Aug 25, 2025
by
zhang
Committed by
GitHub
Aug 25, 2025
Browse files
Remove tma padding for fwd inputs (#85)
parent
c7590278
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
96 deletions
+68
-96
csrc/sm100/collective/fmha_fusion.hpp
csrc/sm100/collective/fmha_fusion.hpp
+3
-3
csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
.../sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
+16
-33
csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
...00/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
+16
-33
csrc/sm100/fmha_cutlass_fwd_sm100.cu
csrc/sm100/fmha_cutlass_fwd_sm100.cu
+3
-2
csrc/sm100/fmha_cutlass_fwd_sm100.cuh
csrc/sm100/fmha_cutlass_fwd_sm100.cuh
+4
-7
csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
...m100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
+23
-6
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+3
-12
No files found.
csrc/sm100/collective/fmha_fusion.hpp
View file @
2d291b0c
...
...
@@ -225,8 +225,8 @@ struct CausalMask : NoMask {
if
constexpr
(
IsQBegin
)
{
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
size
<
0
>
(
tile_shape
),
size
<
1
>
(
tile_shape
))));
}
else
{
const
int
offset_tile_q
=
get
<
1
>
(
problem_size
)
%
get
<
1
>
(
tile_shape
);
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
get
<
0
>
(
tile_shape
)
+
offset_tile_q
,
get
<
1
>
(
tile_shape
))));
const
int
corner_count
=
int
((
get
<
1
>
(
problem_size
)
%
get
<
1
>
(
tile_shape
)
||
get
<
0
>
(
problem_size
)
%
get
<
0
>
(
tile_shape
)
))
;
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
get
<
0
>
(
tile_shape
),
get
<
1
>
(
tile_shape
)))
+
corner_count
);
}
}
...
...
csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
View file @
2d291b0c
...
...
@@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
auto
dQ
=
args
.
dQ
;
auto
dK
=
args
.
dK
;
auto
dV
=
args
.
dV
;
auto
problem_shape_qk
=
problem_shape
;
using
IntProblemShape
=
cute
::
tuple
<
int
,
int
,
int
,
cute
::
tuple
<
cute
::
tuple
<
int
,
int
>
,
int
>>
;
IntProblemShape
problem_shape_qk
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dQ
)
=
get
<
0
>
(
dQ
);
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_q
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
// offset ptr by the amount we add back in later
ptr_Q
-=
max_length_q
*
get
<
0
>
(
dQ
);
}
}
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ProblemShape
>>
)
{
auto
cumulative_length_kv
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_kv
!=
nullptr
)
{
int
max_length_kv
=
get
<
1
>
(
problem_shape
).
max_length
;
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dK
)
=
get
<
0
>
(
dK
);
get
<
2
,
1
>
(
dV
)
=
get
<
0
>
(
dV
);
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_kv
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
// offset ptr by the amount we add back in later
ptr_K
-=
max_length_kv
*
get
<
0
>
(
dK
);
ptr_V
-=
max_length_kv
*
get
<
0
>
(
dV
);
auto
cumulative_length_k
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
&&
cumulative_length_k
!=
nullptr
)
{
get
<
0
>
(
problem_shape_qk
)
=
get
<
0
>
(
problem_shape
).
total_length
;
get
<
1
>
(
problem_shape_qk
)
=
get
<
1
>
(
problem_shape
).
total_length
;
get
<
2
>
(
problem_shape_qk
)
=
get
<
2
>
(
problem_shape
);
get
<
3
>
(
problem_shape_qk
)
=
get
<
3
>
(
problem_shape
);
}
}
else
{
problem_shape_qk
=
problem_shape
;
}
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
...
...
@@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor
mQ_qdl_p
=
params
.
tma_load_q
.
get_tma_tensor
(
select
<
0
,
2
,
3
>
(
problem_shape
));
int
q_offs_0
=
0
;
int
q_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ParamsProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
params_problem_shape
).
max_length
;
q_offs_0
=
max_length_q
-
get
<
0
>
(
problem_shape
);
q_offs_2_1
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)]
+
get
<
0
>
(
problem_shape
);
q_offs_0
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)];
get
<
2
,
1
>
(
blk_coord_q
)
=
0
;
}
}
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
q_offs_2_1
)),
mQ_qdl_p
);
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mQ_qdl_p
);
Tensor
gQ_qdl
=
local_tile
(
mQ_qdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
tSgQ_qdl
=
mma_qk
.
partition_A
(
gQ_qdl
);
...
...
@@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor
mK_kdl_p
=
params
.
tma_load_k
.
get_tma_tensor
(
select
<
1
,
2
,
3
>
(
problem_shape
));
int
kv_offs_0
=
0
;
int
kv_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ParamsProblemShape
>>
)
{
auto
cumulative_length
=
get
<
1
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length
!=
nullptr
)
{
int
max_length
=
get
<
1
>
(
params_problem_shape
).
max_length
;
kv_offs_0
=
max_length
-
get
<
1
>
(
problem_shape
);
kv_offs_2_1
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)]
+
get
<
1
>
(
problem_shape
);
kv_offs_0
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)];
get
<
2
,
1
>
(
blk_coord_kv
)
=
0
;
}
}
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
kv_offs_2_1
)),
mK_kdl_p
);
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mK_kdl_p
);
Tensor
gK_kdl
=
local_tile
(
mK_kdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tSgK_kdl
=
mma_qk
.
partition_B
(
gK_kdl
);
...
...
@@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
ThrMMA
mma_pv
=
typename
CollectiveMmaPV
::
TiledMma
{}.
get_slice
(
0
);
Tensor
mV_dkl_p
=
params
.
tma_load_v
.
get_tma_tensor
(
select
<
2
,
1
,
3
>
(
problem_shape
));
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
kv_offs_2_1
)),
mV_dkl_p
);
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
_0
{}
)),
mV_dkl_p
);
Tensor
gV_dkl
=
local_tile
(
mV_dkl
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tOgV_dkl
=
mma_pv
.
partition_B
(
gV_dkl
);
...
...
csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
View file @
2d291b0c
...
...
@@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
auto
dQ
=
args
.
dQ
;
auto
dK
=
args
.
dK
;
auto
dV
=
args
.
dV
;
auto
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));
using
IntProblemShape
=
cute
::
tuple
<
int
,
int
,
int
,
cute
::
tuple
<
cute
::
tuple
<
int
,
int
>
,
int
>>
;
IntProblemShape
problem_shape_qk
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dQ
)
=
get
<
0
>
(
dQ
);
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_q
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
// offset ptr by the amount we add back in later
ptr_Q
-=
max_length_q
*
get
<
0
>
(
dQ
);
}
}
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ProblemShape
>>
)
{
auto
cumulative_length_kv
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_kv
!=
nullptr
)
{
int
max_length_kv
=
get
<
1
>
(
problem_shape
).
max_length
;
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dK
)
=
get
<
0
>
(
dK
);
get
<
2
,
1
>
(
dV
)
=
get
<
0
>
(
dV
);
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_kv
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
// offset ptr by the amount we add back in later
ptr_K
-=
max_length_kv
*
get
<
0
>
(
dK
);
ptr_V
-=
max_length_kv
*
get
<
0
>
(
dV
);
auto
cumulative_length_k
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
&&
cumulative_length_k
!=
nullptr
)
{
get
<
0
>
(
problem_shape_qk
)
=
get
<
0
>
(
problem_shape
).
total_length
;
get
<
1
>
(
problem_shape_qk
)
=
get
<
1
>
(
problem_shape
).
total_length
;
get
<
2
>
(
problem_shape_qk
)
=
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
);
get
<
3
>
(
problem_shape_qk
)
=
get
<
3
>
(
problem_shape
);
}
}
else
{
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));;
}
auto
problem_shape_pv
=
replace
<
1
>
(
select
<
0
,
2
,
1
,
3
>
(
problem_shape_qk
),
get
<
2
,
0
>
(
problem_shape
));
...
...
@@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor
mQ_qdl_p
=
params
.
tma_load_q
.
get_tma_tensor
(
select
<
0
,
2
,
3
>
(
problem_shape_qk
));
int
q_offs_0
=
0
;
int
q_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ParamsProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
params_problem_shape
).
max_length
;
q_offs_0
=
max_length_q
-
get
<
0
>
(
problem_shape
);
q_offs_2_1
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)]
+
get
<
0
>
(
problem_shape
);
q_offs_0
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)];
get
<
2
,
1
>
(
blk_coord_q
)
=
0
;
}
}
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
q_offs_2_1
)),
mQ_qdl_p
);
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mQ_qdl_p
);
Tensor
gQ_qdl
=
local_tile
(
mQ_qdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
tSgQ_qdl
=
mma_qk
.
partition_A
(
gQ_qdl
);
...
...
@@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor
mK_kdl_p
=
params
.
tma_load_k
.
get_tma_tensor
(
select
<
1
,
2
,
3
>
(
problem_shape_qk
));
int
kv_offs_0
=
0
;
int
kv_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ParamsProblemShape
>>
)
{
auto
cumulative_length
=
get
<
1
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length
!=
nullptr
)
{
int
max_length
=
get
<
1
>
(
params_problem_shape
).
max_length
;
kv_offs_0
=
max_length
-
get
<
1
>
(
problem_shape
);
kv_offs_2_1
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)]
+
get
<
1
>
(
problem_shape
);
kv_offs_0
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)];
get
<
2
,
1
>
(
blk_coord_kv
)
=
0
;
}
}
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
kv_offs_2_1
)),
mK_kdl_p
);
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mK_kdl_p
);
Tensor
gK_kdl
=
local_tile
(
mK_kdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tSgK_kdl
=
mma_qk
.
partition_B
(
gK_kdl
);
...
...
@@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
ThrMMA
mma_pv
=
typename
CollectiveMmaPV
::
TiledMma
{}.
get_slice
(
0
);
Tensor
mV_dkl_p
=
params
.
tma_load_v
.
get_tma_tensor
(
select
<
2
,
1
,
3
>
(
problem_shape_v
));
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
kv_offs_2_1
)),
mV_dkl_p
);
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
_0
{}
)),
mV_dkl_p
);
Tensor
gV_dkl
=
local_tile
(
mV_dkl
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tOgV_dkl
=
mma_pv
.
partition_B
(
gV_dkl
);
...
...
csrc/sm100/fmha_cutlass_fwd_sm100.cu
View file @
2d291b0c
...
...
@@ -18,7 +18,8 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va
static
constexpr
bool
IsVarlen
=
std
::
is_same_v
<
Varlen
,
true_type
>
;
static
constexpr
bool
IsMla
=
std
::
is_same_v
<
Mla
,
true_type
>
;
static
constexpr
bool
IsCausalMask
=
std
::
is_same_v
<
Mask
,
CausalMask
<
false
>>
;
using
Option
=
std
::
conditional_t
<
IsCausalMask
,
Option
<
Tag
::
kIsPersistent
,
false_type
>
,
using
Option
=
std
::
conditional_t
<
IsCausalMask
||
(
IsVarlen
),
Option
<
Tag
::
kIsPersistent
,
false_type
>
,
Option
<
Tag
::
kIsPersistent
,
true_type
>>
;
run_fmha_fwd
<
Element
,
ElementOut
,
IsVarlen
,
IsMla
,
Mask
,
Option
>
(
...
...
csrc/sm100/fmha_cutlass_fwd_sm100.cuh
View file @
2d291b0c
...
...
@@ -143,8 +143,8 @@ struct FwdRunner {
ProblemShapeType
problem_size_for_launch
;
get
<
0
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_q
};
get
<
1
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_kv
};
get
<
0
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_q
,
nullptr
,
total_seqlen_q
};
get
<
1
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_kv
,
nullptr
,
total_seqlen_kv
};
get
<
2
>
(
problem_size_for_launch
)
=
get
<
2
>
(
problem_size
);
get
<
3
>
(
problem_size_for_launch
)
=
get
<
3
>
(
problem_size
);
...
...
@@ -206,10 +206,6 @@ struct FwdRunner {
void
*
q_ptr
,
void
*
k_ptr
,
void
*
v_ptr
,
void
*
o_ptr
,
void
*
lse_ptr
,
void
*
cumulative_length_q
,
void
*
cumulative_length_kv
)
{
auto
problem_shape_
=
problem_shape
;
if
constexpr
(
kIsVarlen
)
{
get
<
0
>
(
problem_shape_
).
cumulative_length
=
static_cast
<
int
*>
(
cumulative_length_q
);
get
<
1
>
(
problem_shape_
).
cumulative_length
=
static_cast
<
int
*>
(
cumulative_length_kv
);
}
typename
Operation
::
Arguments
arguments
{
problem_shape_
,
...
...
@@ -230,6 +226,7 @@ struct FwdRunner {
int
total_seqlen_q
=
q
.
size
(
0
);
int
total_seqlen_kv
=
k
.
size
(
0
);
ProblemShapeType
problem_shape
=
initialize
(
options
,
max_seqlen_q
,
max_seqlen_kv
,
total_seqlen_q
,
total_seqlen_kv
,
cumulative_seqlen_q
.
data_ptr
(),
cumulative_seqlen_kv
.
data_ptr
());
...
...
@@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v
auto
options
=
get_options
();
if
(
options
.
h
%
cutlass
::
fmha
::
kernel
::
CausalIndividualTileScheduler
::
TileH
==
0
&&
(
!
std
::
is_same_v
<
ActiveMask
,
NoMask
>
))
{
(
std
::
is_same_v
<
ActiveMask
,
CausalMask
<
false
>>
||
std
::
is_same_v
<
ActiveMask
,
CausalMask
<
true
>
>
))
{
FwdRunner
<
kIsMla
,
true
,
kIsVarlen
,
DTypeIn
,
DTypeOut
,
ActiveMask
,
KernelOptions
...
>
runner
;
runner
.
run
(
options
,
hw_info
,
q
,
k
,
v
,
o
,
lse
,
scale_softmax
,
workspace
,
cumulative_seqlen_q
,
cumulative_seqlen_kv
,
max_seqlen_q
,
max_seqlen_kv
);
...
...
csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
View file @
2d291b0c
...
...
@@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else
if
(
role
==
WarpRole
::
Correction
)
{
cutlass
::
arch
::
warpgroup_reg_dealloc
<
NumRegsCorrection
>
();
bool
has_valid
=
false
;
CUTLASS_PRAGMA_NO_UNROLL
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
auto
blk_coord
=
tile_scheduler
.
get_block_coord
();
...
...
@@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue
;
}
has_valid
=
true
;
if
(
get
<
1
>
(
logical_problem_shape
)
==
0
)
{
mainloop
.
correction_empty
(
blk_coord
,
...
...
@@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
if
constexpr
(
NumWarpsEpilogue
==
0
)
{
static_assert
(
NumWarpsCorrection
==
1
);
if
(
has_valid
)
{
uint32_t
free_stage_ptr
=
shared_storage
.
tmem_base_ptr
;
tmem_allocator
.
free
(
free_stage_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
}
}
}
else
if
(
role
==
WarpRole
::
MMA
)
{
warpgroup_reg_set
<
NumRegsOther
>
();
tmem_allocator
.
allocate
(
TmemAllocator
::
Sm100TmemCapacityColumns
,
&
shared_storage
.
tmem_base_ptr
);
__syncwarp
();
bool
allocated
=
false
;
CUTLASS_PRAGMA_NO_UNROLL
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
...
...
@@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue
;
}
if
(
!
allocated
)
{
tmem_allocator
.
allocate
(
TmemAllocator
::
Sm100TmemCapacityColumns
,
&
shared_storage
.
tmem_base_ptr
);
__syncwarp
();
allocated
=
true
;
}
if
(
get
<
1
>
(
logical_problem_shape
)
==
0
)
{
continue
;
}
...
...
@@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else
if
(
role
==
WarpRole
::
Epilogue
)
{
warpgroup_reg_set
<
NumRegsOther
>
();
bool
has_valid
=
false
;
CUTLASS_PRAGMA_NO_UNROLL
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
auto
blk_coord
=
tile_scheduler
.
get_block_coord
();
...
...
@@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue
;
}
has_valid
=
true
;
epilogue
.
store
(
blk_coord
,
logical_problem_shape
,
params
.
epilogue
,
params
.
problem_shape
,
...
...
@@ -602,9 +617,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static_assert
(
NumWarpsEpilogue
<=
1
);
if
constexpr
(
NumWarpsEpilogue
==
1
)
{
if
(
has_valid
)
{
uint32_t
free_stage_ptr
=
shared_storage
.
tmem_base_ptr
;
tmem_allocator
.
free
(
free_stage_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
}
}
}
else
if
(
role
==
WarpRole
::
Empty
)
{
...
...
tests/test_fmha_sm100.py
View file @
2d291b0c
...
...
@@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
offst_q
=
total_q
offst_kv
=
total_k
q1_with_buffer
=
torch
.
empty
(
total_q
+
total_q
,
h
,
d
,
device
=
device
,
dtype
=
dtype
)
k1_with_buffer
=
torch
.
empty
(
offst_kv
+
total_k
,
h_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v1_with_buffer
=
torch
.
empty
(
offst_kv
+
total_k
,
h_k
,
dv
,
device
=
device
,
dtype
=
dtype
)
q1_with_buffer
[
total_q
:]
=
q
k1_with_buffer
[
offst_kv
:]
=
k
v1_with_buffer
[
offst_kv
:]
=
v
q1
=
q1_with_buffer
[
offst_q
:].
requires_grad_
()
k1
=
k1_with_buffer
[
offst_kv
:].
requires_grad_
()
v1
=
v1_with_buffer
[
offst_kv
:].
requires_grad_
()
q1
=
q
.
clone
().
requires_grad_
()
k1
=
k
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment