Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2d291b0c
Unverified
Commit
2d291b0c
authored
Aug 25, 2025
by
zhang
Committed by
GitHub
Aug 25, 2025
Browse files
Remove tma padding for fwd inputs (#85)
parent
c7590278
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
96 deletions
+68
-96
csrc/sm100/collective/fmha_fusion.hpp
csrc/sm100/collective/fmha_fusion.hpp
+3
-3
csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
.../sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
+16
-33
csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
...00/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
+16
-33
csrc/sm100/fmha_cutlass_fwd_sm100.cu
csrc/sm100/fmha_cutlass_fwd_sm100.cu
+3
-2
csrc/sm100/fmha_cutlass_fwd_sm100.cuh
csrc/sm100/fmha_cutlass_fwd_sm100.cuh
+4
-7
csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
...m100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
+23
-6
tests/test_fmha_sm100.py
tests/test_fmha_sm100.py
+3
-12
No files found.
csrc/sm100/collective/fmha_fusion.hpp
View file @
2d291b0c
...
@@ -220,13 +220,13 @@ struct CausalMask : NoMask {
...
@@ -220,13 +220,13 @@ struct CausalMask : NoMask {
BlkCoord
const
&
blk_coord
,
BlkCoord
const
&
blk_coord
,
TileShape
const
&
tile_shape
,
TileShape
const
&
tile_shape
,
ProblemSize
const
&
problem_size
)
{
ProblemSize
const
&
problem_size
)
{
int
trip_count
=
get_trip_count
(
blk_coord
,
tile_shape
,
problem_size
);
int
trip_count
=
get_trip_count
(
blk_coord
,
tile_shape
,
problem_size
);
if
constexpr
(
IsQBegin
)
{
if
constexpr
(
IsQBegin
)
{
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
size
<
0
>
(
tile_shape
),
size
<
1
>
(
tile_shape
))));
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
size
<
0
>
(
tile_shape
),
size
<
1
>
(
tile_shape
))));
}
else
{
}
else
{
const
int
offset_tile_q
=
get
<
1
>
(
problem_size
)
%
get
<
1
>
(
tile_shape
);
const
int
corner_count
=
int
((
get
<
1
>
(
problem_size
)
%
get
<
1
>
(
tile_shape
)
||
get
<
0
>
(
problem_size
)
%
get
<
0
>
(
tile_shape
)
))
;
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
get
<
0
>
(
tile_shape
)
+
offset_tile_q
,
get
<
1
>
(
tile_shape
))));
return
std
::
min
(
trip_count
,
int
(
ceil_div
(
get
<
0
>
(
tile_shape
),
get
<
1
>
(
tile_shape
)))
+
corner_count
);
}
}
}
}
...
...
csrc/sm100/collective/sm100_fmha_load_tma_warpspecialized.hpp
View file @
2d291b0c
...
@@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
...
@@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
auto
dQ
=
args
.
dQ
;
auto
dQ
=
args
.
dQ
;
auto
dK
=
args
.
dK
;
auto
dK
=
args
.
dK
;
auto
dV
=
args
.
dV
;
auto
dV
=
args
.
dV
;
auto
problem_shape_qk
=
problem_shape
;
using
IntProblemShape
=
cute
::
tuple
<
int
,
int
,
int
,
cute
::
tuple
<
cute
::
tuple
<
int
,
int
>
,
int
>>
;
IntProblemShape
problem_shape_qk
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ProblemShape
>>
)
{
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
auto
cumulative_length_k
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
if
(
cumulative_length_q
!=
nullptr
&&
cumulative_length_k
!=
nullptr
)
{
// for variable sequence lenght, the batch is in units of row_stride
get
<
0
>
(
problem_shape_qk
)
=
get
<
0
>
(
problem_shape
).
total_length
;
get
<
2
,
1
>
(
dQ
)
=
get
<
0
>
(
dQ
);
get
<
1
>
(
problem_shape_qk
)
=
get
<
1
>
(
problem_shape
).
total_length
;
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_q
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
get
<
2
>
(
problem_shape_qk
)
=
get
<
2
>
(
problem_shape
);
// offset ptr by the amount we add back in later
get
<
3
>
(
problem_shape_qk
)
=
get
<
3
>
(
problem_shape
);
ptr_Q
-=
max_length_q
*
get
<
0
>
(
dQ
);
}
}
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ProblemShape
>>
)
{
auto
cumulative_length_kv
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_kv
!=
nullptr
)
{
int
max_length_kv
=
get
<
1
>
(
problem_shape
).
max_length
;
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dK
)
=
get
<
0
>
(
dK
);
get
<
2
,
1
>
(
dV
)
=
get
<
0
>
(
dV
);
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_kv
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
// offset ptr by the amount we add back in later
ptr_K
-=
max_length_kv
*
get
<
0
>
(
dK
);
ptr_V
-=
max_length_kv
*
get
<
0
>
(
dV
);
}
}
}
else
{
problem_shape_qk
=
problem_shape
;
}
}
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
auto
params_qk
=
CollectiveMmaQK
::
to_underlying_arguments
(
...
@@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
...
@@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor
mQ_qdl_p
=
params
.
tma_load_q
.
get_tma_tensor
(
select
<
0
,
2
,
3
>
(
problem_shape
));
Tensor
mQ_qdl_p
=
params
.
tma_load_q
.
get_tma_tensor
(
select
<
0
,
2
,
3
>
(
problem_shape
));
int
q_offs_0
=
0
;
int
q_offs_0
=
0
;
int
q_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ParamsProblemShape
>>
)
{
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ParamsProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
params_problem_shape
).
cumulative_length
;
auto
cumulative_length_q
=
get
<
0
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
params_problem_shape
).
max_length
;
q_offs_0
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)];
q_offs_0
=
max_length_q
-
get
<
0
>
(
problem_shape
);
q_offs_2_1
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)]
+
get
<
0
>
(
problem_shape
);
get
<
2
,
1
>
(
blk_coord_q
)
=
0
;
get
<
2
,
1
>
(
blk_coord_q
)
=
0
;
}
}
}
}
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
q_offs_2_1
)),
mQ_qdl_p
);
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mQ_qdl_p
);
Tensor
gQ_qdl
=
local_tile
(
mQ_qdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
gQ_qdl
=
local_tile
(
mQ_qdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
tSgQ_qdl
=
mma_qk
.
partition_A
(
gQ_qdl
);
Tensor
tSgQ_qdl
=
mma_qk
.
partition_A
(
gQ_qdl
);
...
@@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
...
@@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor
mK_kdl_p
=
params
.
tma_load_k
.
get_tma_tensor
(
select
<
1
,
2
,
3
>
(
problem_shape
));
Tensor
mK_kdl_p
=
params
.
tma_load_k
.
get_tma_tensor
(
select
<
1
,
2
,
3
>
(
problem_shape
));
int
kv_offs_0
=
0
;
int
kv_offs_0
=
0
;
int
kv_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ParamsProblemShape
>>
)
{
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ParamsProblemShape
>>
)
{
auto
cumulative_length
=
get
<
1
>
(
params_problem_shape
).
cumulative_length
;
auto
cumulative_length
=
get
<
1
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length
!=
nullptr
)
{
if
(
cumulative_length
!=
nullptr
)
{
int
max_length
=
get
<
1
>
(
params_problem_shape
).
max_length
;
kv_offs_0
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)];
kv_offs_0
=
max_length
-
get
<
1
>
(
problem_shape
);
kv_offs_2_1
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)]
+
get
<
1
>
(
problem_shape
);
get
<
2
,
1
>
(
blk_coord_kv
)
=
0
;
get
<
2
,
1
>
(
blk_coord_kv
)
=
0
;
}
}
}
}
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
kv_offs_2_1
)),
mK_kdl_p
);
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mK_kdl_p
);
Tensor
gK_kdl
=
local_tile
(
mK_kdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
gK_kdl
=
local_tile
(
mK_kdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tSgK_kdl
=
mma_qk
.
partition_B
(
gK_kdl
);
Tensor
tSgK_kdl
=
mma_qk
.
partition_B
(
gK_kdl
);
...
@@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
...
@@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
ThrMMA
mma_pv
=
typename
CollectiveMmaPV
::
TiledMma
{}.
get_slice
(
0
);
ThrMMA
mma_pv
=
typename
CollectiveMmaPV
::
TiledMma
{}.
get_slice
(
0
);
Tensor
mV_dkl_p
=
params
.
tma_load_v
.
get_tma_tensor
(
select
<
2
,
1
,
3
>
(
problem_shape
));
Tensor
mV_dkl_p
=
params
.
tma_load_v
.
get_tma_tensor
(
select
<
2
,
1
,
3
>
(
problem_shape
));
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
kv_offs_2_1
)),
mV_dkl_p
);
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
_0
{}
)),
mV_dkl_p
);
Tensor
gV_dkl
=
local_tile
(
mV_dkl
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
gV_dkl
=
local_tile
(
mV_dkl
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tOgV_dkl
=
mma_pv
.
partition_B
(
gV_dkl
);
Tensor
tOgV_dkl
=
mma_pv
.
partition_B
(
gV_dkl
);
...
...
csrc/sm100/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp
View file @
2d291b0c
...
@@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
...
@@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
auto
dQ
=
args
.
dQ
;
auto
dQ
=
args
.
dQ
;
auto
dK
=
args
.
dK
;
auto
dK
=
args
.
dK
;
auto
dV
=
args
.
dV
;
auto
dV
=
args
.
dV
;
auto
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));
using
IntProblemShape
=
cute
::
tuple
<
int
,
int
,
int
,
cute
::
tuple
<
cute
::
tuple
<
int
,
int
>
,
int
>>
;
IntProblemShape
problem_shape_qk
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ProblemShape
>>
)
{
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
auto
cumulative_length_q
=
get
<
0
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
auto
cumulative_length_k
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
int
max_length_q
=
get
<
0
>
(
problem_shape
).
max_length
;
if
(
cumulative_length_q
!=
nullptr
&&
cumulative_length_k
!=
nullptr
)
{
// for variable sequence lenght, the batch is in units of row_stride
get
<
0
>
(
problem_shape_qk
)
=
get
<
0
>
(
problem_shape
).
total_length
;
get
<
2
,
1
>
(
dQ
)
=
get
<
0
>
(
dQ
);
get
<
1
>
(
problem_shape_qk
)
=
get
<
1
>
(
problem_shape
).
total_length
;
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_q
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
get
<
2
>
(
problem_shape_qk
)
=
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
);
// offset ptr by the amount we add back in later
get
<
3
>
(
problem_shape_qk
)
=
get
<
3
>
(
problem_shape
);
ptr_Q
-=
max_length_q
*
get
<
0
>
(
dQ
);
}
}
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ProblemShape
>>
)
{
auto
cumulative_length_kv
=
get
<
1
>
(
problem_shape
).
cumulative_length
;
if
(
cumulative_length_kv
!=
nullptr
)
{
int
max_length_kv
=
get
<
1
>
(
problem_shape
).
max_length
;
// for variable sequence lenght, the batch is in units of row_stride
get
<
2
,
1
>
(
dK
)
=
get
<
0
>
(
dK
);
get
<
2
,
1
>
(
dV
)
=
get
<
0
>
(
dV
);
get
<
3
,
1
>
(
problem_shape_qk
)
=
std
::
max
(
get
<
3
,
1
>
(
problem_shape_qk
),
max_length_kv
*
(
1
+
get
<
3
,
1
>
(
problem_shape
)));
// offset ptr by the amount we add back in later
ptr_K
-=
max_length_kv
*
get
<
0
>
(
dK
);
ptr_V
-=
max_length_kv
*
get
<
0
>
(
dV
);
}
}
}
else
{
problem_shape_qk
=
replace
<
2
>
(
problem_shape
,
get
<
2
,
0
>
(
problem_shape
)
+
get
<
2
,
1
>
(
problem_shape
));;
}
}
auto
problem_shape_pv
=
replace
<
1
>
(
select
<
0
,
2
,
1
,
3
>
(
problem_shape_qk
),
get
<
2
,
0
>
(
problem_shape
));
auto
problem_shape_pv
=
replace
<
1
>
(
select
<
0
,
2
,
1
,
3
>
(
problem_shape_qk
),
get
<
2
,
0
>
(
problem_shape
));
...
@@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
...
@@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor
mQ_qdl_p
=
params
.
tma_load_q
.
get_tma_tensor
(
select
<
0
,
2
,
3
>
(
problem_shape_qk
));
Tensor
mQ_qdl_p
=
params
.
tma_load_q
.
get_tma_tensor
(
select
<
0
,
2
,
3
>
(
problem_shape_qk
));
int
q_offs_0
=
0
;
int
q_offs_0
=
0
;
int
q_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ParamsProblemShape
>>
)
{
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
0
,
ParamsProblemShape
>>
)
{
auto
cumulative_length_q
=
get
<
0
>
(
params_problem_shape
).
cumulative_length
;
auto
cumulative_length_q
=
get
<
0
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length_q
!=
nullptr
)
{
if
(
cumulative_length_q
!=
nullptr
)
{
int
max_length_q
=
get
<
0
>
(
params_problem_shape
).
max_length
;
q_offs_0
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)];
q_offs_0
=
max_length_q
-
get
<
0
>
(
problem_shape
);
q_offs_2_1
=
cumulative_length_q
[
get
<
2
,
1
>
(
blk_coord_q
)]
+
get
<
0
>
(
problem_shape
);
get
<
2
,
1
>
(
blk_coord_q
)
=
0
;
get
<
2
,
1
>
(
blk_coord_q
)
=
0
;
}
}
}
}
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
q_offs_2_1
)),
mQ_qdl_p
);
Tensor
mQ_qdl
=
domain_offset
(
make_coord
(
q_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mQ_qdl_p
);
Tensor
gQ_qdl
=
local_tile
(
mQ_qdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
gQ_qdl
=
local_tile
(
mQ_qdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
Tensor
tSgQ_qdl
=
mma_qk
.
partition_A
(
gQ_qdl
);
Tensor
tSgQ_qdl
=
mma_qk
.
partition_A
(
gQ_qdl
);
...
@@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
...
@@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor
mK_kdl_p
=
params
.
tma_load_k
.
get_tma_tensor
(
select
<
1
,
2
,
3
>
(
problem_shape_qk
));
Tensor
mK_kdl_p
=
params
.
tma_load_k
.
get_tma_tensor
(
select
<
1
,
2
,
3
>
(
problem_shape_qk
));
int
kv_offs_0
=
0
;
int
kv_offs_0
=
0
;
int
kv_offs_2_1
=
0
;
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ParamsProblemShape
>>
)
{
if
constexpr
(
is_variable_length_v
<
tuple_element_t
<
1
,
ParamsProblemShape
>>
)
{
auto
cumulative_length
=
get
<
1
>
(
params_problem_shape
).
cumulative_length
;
auto
cumulative_length
=
get
<
1
>
(
params_problem_shape
).
cumulative_length
;
if
(
cumulative_length
!=
nullptr
)
{
if
(
cumulative_length
!=
nullptr
)
{
int
max_length
=
get
<
1
>
(
params_problem_shape
).
max_length
;
kv_offs_0
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)];
kv_offs_0
=
max_length
-
get
<
1
>
(
problem_shape
);
kv_offs_2_1
=
cumulative_length
[
get
<
2
,
1
>
(
blk_coord_kv
)]
+
get
<
1
>
(
problem_shape
);
get
<
2
,
1
>
(
blk_coord_kv
)
=
0
;
get
<
2
,
1
>
(
blk_coord_kv
)
=
0
;
}
}
}
}
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
kv_offs_2_1
)),
mK_kdl_p
);
Tensor
mK_kdl
=
domain_offset
(
make_coord
(
kv_offs_0
,
_0
{},
make_coord
(
_0
{},
_0
{}
)),
mK_kdl_p
);
Tensor
gK_kdl
=
local_tile
(
mK_kdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
gK_kdl
=
local_tile
(
mK_kdl
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tSgK_kdl
=
mma_qk
.
partition_B
(
gK_kdl
);
Tensor
tSgK_kdl
=
mma_qk
.
partition_B
(
gK_kdl
);
...
@@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
...
@@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
ThrMMA
mma_pv
=
typename
CollectiveMmaPV
::
TiledMma
{}.
get_slice
(
0
);
ThrMMA
mma_pv
=
typename
CollectiveMmaPV
::
TiledMma
{}.
get_slice
(
0
);
Tensor
mV_dkl_p
=
params
.
tma_load_v
.
get_tma_tensor
(
select
<
2
,
1
,
3
>
(
problem_shape_v
));
Tensor
mV_dkl_p
=
params
.
tma_load_v
.
get_tma_tensor
(
select
<
2
,
1
,
3
>
(
problem_shape_v
));
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
kv_offs_2_1
)),
mV_dkl_p
);
Tensor
mV_dkl
=
domain_offset
(
make_coord
(
_0
{},
kv_offs_0
,
make_coord
(
_0
{},
_0
{}
)),
mV_dkl_p
);
Tensor
gV_dkl
=
local_tile
(
mV_dkl
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
gV_dkl
=
local_tile
(
mV_dkl
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
Tensor
tOgV_dkl
=
mma_pv
.
partition_B
(
gV_dkl
);
Tensor
tOgV_dkl
=
mma_pv
.
partition_B
(
gV_dkl
);
...
...
csrc/sm100/fmha_cutlass_fwd_sm100.cu
View file @
2d291b0c
...
@@ -18,8 +18,9 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va
...
@@ -18,8 +18,9 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va
static
constexpr
bool
IsVarlen
=
std
::
is_same_v
<
Varlen
,
true_type
>
;
static
constexpr
bool
IsVarlen
=
std
::
is_same_v
<
Varlen
,
true_type
>
;
static
constexpr
bool
IsMla
=
std
::
is_same_v
<
Mla
,
true_type
>
;
static
constexpr
bool
IsMla
=
std
::
is_same_v
<
Mla
,
true_type
>
;
static
constexpr
bool
IsCausalMask
=
std
::
is_same_v
<
Mask
,
CausalMask
<
false
>>
;
static
constexpr
bool
IsCausalMask
=
std
::
is_same_v
<
Mask
,
CausalMask
<
false
>>
;
using
Option
=
std
::
conditional_t
<
IsCausalMask
,
Option
<
Tag
::
kIsPersistent
,
false_type
>
,
using
Option
=
Option
<
Tag
::
kIsPersistent
,
true_type
>>
;
std
::
conditional_t
<
IsCausalMask
||
(
IsVarlen
),
Option
<
Tag
::
kIsPersistent
,
false_type
>
,
Option
<
Tag
::
kIsPersistent
,
true_type
>>
;
run_fmha_fwd
<
Element
,
ElementOut
,
IsVarlen
,
IsMla
,
Mask
,
Option
>
(
run_fmha_fwd
<
Element
,
ElementOut
,
IsVarlen
,
IsMla
,
Mask
,
Option
>
(
workspace_buffer
,
q
,
k
,
v
,
cumulative_seqlen_q
,
cumulative_seqlen_kv
,
o
,
lse
,
workspace_buffer
,
q
,
k
,
v
,
cumulative_seqlen_q
,
cumulative_seqlen_kv
,
o
,
lse
,
...
...
csrc/sm100/fmha_cutlass_fwd_sm100.cuh
View file @
2d291b0c
...
@@ -143,8 +143,8 @@ struct FwdRunner {
...
@@ -143,8 +143,8 @@ struct FwdRunner {
ProblemShapeType
problem_size_for_launch
;
ProblemShapeType
problem_size_for_launch
;
get
<
0
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_q
};
get
<
0
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_q
,
nullptr
,
total_seqlen_q
};
get
<
1
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_kv
};
get
<
1
>
(
problem_size_for_launch
)
=
VariableLength
{
max_seqlen_kv
,
nullptr
,
total_seqlen_kv
};
get
<
2
>
(
problem_size_for_launch
)
=
get
<
2
>
(
problem_size
);
get
<
2
>
(
problem_size_for_launch
)
=
get
<
2
>
(
problem_size
);
get
<
3
>
(
problem_size_for_launch
)
=
get
<
3
>
(
problem_size
);
get
<
3
>
(
problem_size_for_launch
)
=
get
<
3
>
(
problem_size
);
...
@@ -206,10 +206,6 @@ struct FwdRunner {
...
@@ -206,10 +206,6 @@ struct FwdRunner {
void
*
q_ptr
,
void
*
k_ptr
,
void
*
v_ptr
,
void
*
o_ptr
,
void
*
lse_ptr
,
void
*
q_ptr
,
void
*
k_ptr
,
void
*
v_ptr
,
void
*
o_ptr
,
void
*
lse_ptr
,
void
*
cumulative_length_q
,
void
*
cumulative_length_kv
)
{
void
*
cumulative_length_q
,
void
*
cumulative_length_kv
)
{
auto
problem_shape_
=
problem_shape
;
auto
problem_shape_
=
problem_shape
;
if
constexpr
(
kIsVarlen
)
{
get
<
0
>
(
problem_shape_
).
cumulative_length
=
static_cast
<
int
*>
(
cumulative_length_q
);
get
<
1
>
(
problem_shape_
).
cumulative_length
=
static_cast
<
int
*>
(
cumulative_length_kv
);
}
typename
Operation
::
Arguments
arguments
{
typename
Operation
::
Arguments
arguments
{
problem_shape_
,
problem_shape_
,
...
@@ -230,6 +226,7 @@ struct FwdRunner {
...
@@ -230,6 +226,7 @@ struct FwdRunner {
int
total_seqlen_q
=
q
.
size
(
0
);
int
total_seqlen_q
=
q
.
size
(
0
);
int
total_seqlen_kv
=
k
.
size
(
0
);
int
total_seqlen_kv
=
k
.
size
(
0
);
ProblemShapeType
problem_shape
=
ProblemShapeType
problem_shape
=
initialize
(
options
,
max_seqlen_q
,
max_seqlen_kv
,
total_seqlen_q
,
total_seqlen_kv
,
initialize
(
options
,
max_seqlen_q
,
max_seqlen_kv
,
total_seqlen_q
,
total_seqlen_kv
,
cumulative_seqlen_q
.
data_ptr
(),
cumulative_seqlen_kv
.
data_ptr
());
cumulative_seqlen_q
.
data_ptr
(),
cumulative_seqlen_kv
.
data_ptr
());
...
@@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v
...
@@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v
auto
options
=
get_options
();
auto
options
=
get_options
();
if
(
options
.
h
%
cutlass
::
fmha
::
kernel
::
CausalIndividualTileScheduler
::
TileH
==
0
&&
if
(
options
.
h
%
cutlass
::
fmha
::
kernel
::
CausalIndividualTileScheduler
::
TileH
==
0
&&
(
!
std
::
is_same_v
<
ActiveMask
,
NoMask
>
))
{
(
std
::
is_same_v
<
ActiveMask
,
CausalMask
<
false
>>
||
std
::
is_same_v
<
ActiveMask
,
CausalMask
<
true
>
>
))
{
FwdRunner
<
kIsMla
,
true
,
kIsVarlen
,
DTypeIn
,
DTypeOut
,
ActiveMask
,
KernelOptions
...
>
runner
;
FwdRunner
<
kIsMla
,
true
,
kIsVarlen
,
DTypeIn
,
DTypeOut
,
ActiveMask
,
KernelOptions
...
>
runner
;
runner
.
run
(
options
,
hw_info
,
q
,
k
,
v
,
o
,
lse
,
scale_softmax
,
workspace
,
cumulative_seqlen_q
,
runner
.
run
(
options
,
hw_info
,
q
,
k
,
v
,
o
,
lse
,
scale_softmax
,
workspace
,
cumulative_seqlen_q
,
cumulative_seqlen_kv
,
max_seqlen_q
,
max_seqlen_kv
);
cumulative_seqlen_kv
,
max_seqlen_q
,
max_seqlen_kv
);
...
...
csrc/sm100/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp
View file @
2d291b0c
...
@@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else
if
(
role
==
WarpRole
::
Correction
)
{
else
if
(
role
==
WarpRole
::
Correction
)
{
cutlass
::
arch
::
warpgroup_reg_dealloc
<
NumRegsCorrection
>
();
cutlass
::
arch
::
warpgroup_reg_dealloc
<
NumRegsCorrection
>
();
bool
has_valid
=
false
;
CUTLASS_PRAGMA_NO_UNROLL
CUTLASS_PRAGMA_NO_UNROLL
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
auto
blk_coord
=
tile_scheduler
.
get_block_coord
();
auto
blk_coord
=
tile_scheduler
.
get_block_coord
();
...
@@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue
;
continue
;
}
}
has_valid
=
true
;
if
(
get
<
1
>
(
logical_problem_shape
)
==
0
)
{
if
(
get
<
1
>
(
logical_problem_shape
)
==
0
)
{
mainloop
.
correction_empty
(
mainloop
.
correction_empty
(
blk_coord
,
blk_coord
,
...
@@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
if
constexpr
(
NumWarpsEpilogue
==
0
)
{
if
constexpr
(
NumWarpsEpilogue
==
0
)
{
static_assert
(
NumWarpsCorrection
==
1
);
static_assert
(
NumWarpsCorrection
==
1
);
uint32_t
free_stage_ptr
=
shared_storage
.
tmem_base_ptr
;
if
(
has_valid
)
{
tmem_allocator
.
free
(
free_stage_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
uint32_t
free_stage_ptr
=
shared_storage
.
tmem_base_ptr
;
tmem_allocator
.
free
(
free_stage_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
}
}
}
}
}
else
if
(
role
==
WarpRole
::
MMA
)
{
else
if
(
role
==
WarpRole
::
MMA
)
{
warpgroup_reg_set
<
NumRegsOther
>
();
warpgroup_reg_set
<
NumRegsOther
>
();
tmem_allocator
.
allocate
(
TmemAllocator
::
Sm100TmemCapacityColumns
,
&
shared_storage
.
tmem_base_ptr
);
bool
allocated
=
false
;
__syncwarp
();
CUTLASS_PRAGMA_NO_UNROLL
CUTLASS_PRAGMA_NO_UNROLL
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
...
@@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue
;
continue
;
}
}
if
(
!
allocated
)
{
tmem_allocator
.
allocate
(
TmemAllocator
::
Sm100TmemCapacityColumns
,
&
shared_storage
.
tmem_base_ptr
);
__syncwarp
();
allocated
=
true
;
}
if
(
get
<
1
>
(
logical_problem_shape
)
==
0
)
{
if
(
get
<
1
>
(
logical_problem_shape
)
==
0
)
{
continue
;
continue
;
}
}
...
@@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else
if
(
role
==
WarpRole
::
Epilogue
)
{
else
if
(
role
==
WarpRole
::
Epilogue
)
{
warpgroup_reg_set
<
NumRegsOther
>
();
warpgroup_reg_set
<
NumRegsOther
>
();
bool
has_valid
=
false
;
CUTLASS_PRAGMA_NO_UNROLL
CUTLASS_PRAGMA_NO_UNROLL
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
for
(;
tile_scheduler
.
is_valid
();
++
tile_scheduler
)
{
auto
blk_coord
=
tile_scheduler
.
get_block_coord
();
auto
blk_coord
=
tile_scheduler
.
get_block_coord
();
...
@@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue
;
continue
;
}
}
has_valid
=
true
;
epilogue
.
store
(
epilogue
.
store
(
blk_coord
,
logical_problem_shape
,
blk_coord
,
logical_problem_shape
,
params
.
epilogue
,
params
.
problem_shape
,
params
.
epilogue
,
params
.
problem_shape
,
...
@@ -602,8 +617,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
...
@@ -602,8 +617,10 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static_assert
(
NumWarpsEpilogue
<=
1
);
static_assert
(
NumWarpsEpilogue
<=
1
);
if
constexpr
(
NumWarpsEpilogue
==
1
)
{
if
constexpr
(
NumWarpsEpilogue
==
1
)
{
uint32_t
free_stage_ptr
=
shared_storage
.
tmem_base_ptr
;
if
(
has_valid
)
{
tmem_allocator
.
free
(
free_stage_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
uint32_t
free_stage_ptr
=
shared_storage
.
tmem_base_ptr
;
tmem_allocator
.
free
(
free_stage_ptr
,
TmemAllocator
::
Sm100TmemCapacityColumns
);
}
}
}
}
}
...
...
tests/test_fmha_sm100.py
View file @
2d291b0c
...
@@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
...
@@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
grad_out
=
torch
.
randn
(
total_q
,
h
,
dv
)
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
softmax_scale
=
(
d
+
100
)
**
(
-
0.5
)
offst_q
=
total_q
q1
=
q
.
clone
().
requires_grad_
()
offst_kv
=
total_k
k1
=
k
.
clone
().
requires_grad_
()
v1
=
v
.
clone
().
requires_grad_
()
q1_with_buffer
=
torch
.
empty
(
total_q
+
total_q
,
h
,
d
,
device
=
device
,
dtype
=
dtype
)
k1_with_buffer
=
torch
.
empty
(
offst_kv
+
total_k
,
h_k
,
d
,
device
=
device
,
dtype
=
dtype
)
v1_with_buffer
=
torch
.
empty
(
offst_kv
+
total_k
,
h_k
,
dv
,
device
=
device
,
dtype
=
dtype
)
q1_with_buffer
[
total_q
:]
=
q
k1_with_buffer
[
offst_kv
:]
=
k
v1_with_buffer
[
offst_kv
:]
=
v
q1
=
q1_with_buffer
[
offst_q
:].
requires_grad_
()
k1
=
k1_with_buffer
[
offst_kv
:].
requires_grad_
()
v1
=
v1_with_buffer
[
offst_kv
:].
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
q2
=
q
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
k2
=
k
.
clone
().
requires_grad_
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment