Unverified Commit 2d291b0c authored by zhang's avatar zhang Committed by GitHub
Browse files

Remove tma padding for fwd inputs (#85)

parent c7590278
...@@ -225,8 +225,8 @@ struct CausalMask : NoMask { ...@@ -225,8 +225,8 @@ struct CausalMask : NoMask {
if constexpr (IsQBegin) { if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else { } else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape); const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
} }
} }
......
...@@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized { ...@@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
auto dQ = args.dQ; auto dQ = args.dQ;
auto dK = args.dK; auto dK = args.dK;
auto dV = args.dV; auto dV = args.dV;
auto problem_shape_qk = problem_shape;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) { if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length; auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) { auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
int max_length_q = get<0>(problem_shape).max_length; if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
// for variable sequence lenght, the batch is in units of row_stride get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<2,1>(dQ) = get<0>(dQ); get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); get<2>(problem_shape_qk) = get<2>(problem_shape);
// offset ptr by the amount we add back in later get<3>(problem_shape_qk) = get<3>(problem_shape);
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
} }
} else {
problem_shape_qk = problem_shape;
} }
auto params_qk = CollectiveMmaQK::to_underlying_arguments( auto params_qk = CollectiveMmaQK::to_underlying_arguments(
...@@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { ...@@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
int q_offs_0 = 0; int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) { if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) { if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length; q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
get<2,1>(blk_coord_q) = 0; get<2,1>(blk_coord_q) = 0;
} }
} }
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
...@@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized { ...@@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
int kv_offs_0 = 0; int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) { if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length; auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) { if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length; kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
get<2,1>(blk_coord_kv) = 0; get<2,1>(blk_coord_kv) = 0;
} }
} }
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{}); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
...@@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized { ...@@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{}); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
......
...@@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ...@@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
auto dQ = args.dQ; auto dQ = args.dQ;
auto dK = args.dK; auto dK = args.dK;
auto dV = args.dV; auto dV = args.dV;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) { if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length; auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) { auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
int max_length_q = get<0>(problem_shape).max_length; if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
// for variable sequence lenght, the batch is in units of row_stride get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<2,1>(dQ) = get<0>(dQ); get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape))); get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
// offset ptr by the amount we add back in later get<3>(problem_shape_qk) = get<3>(problem_shape);
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
} }
} else {
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
} }
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
...@@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ...@@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0; int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) { if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) { if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length; q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
get<2,1>(blk_coord_q) = 0; get<2,1>(blk_coord_q) = 0;
} }
} }
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
...@@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ...@@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0; int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) { if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length; auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) { if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length; kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
get<2,1>(blk_coord_kv) = 0; get<2,1>(blk_coord_kv) = 0;
} }
} }
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p); Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{}); Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
...@@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ...@@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p); Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{}); Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
......
...@@ -18,7 +18,8 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va ...@@ -18,7 +18,8 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va
static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>; static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;
static constexpr bool IsMla = std::is_same_v<Mla, true_type>; static constexpr bool IsMla = std::is_same_v<Mla, true_type>;
static constexpr bool IsCausalMask = std::is_same_v<Mask, CausalMask<false>>; static constexpr bool IsCausalMask = std::is_same_v<Mask, CausalMask<false>>;
using Option = std::conditional_t<IsCausalMask, Option<Tag::kIsPersistent, false_type>, using Option =
std::conditional_t<IsCausalMask || (IsVarlen), Option<Tag::kIsPersistent, false_type>,
Option<Tag::kIsPersistent, true_type>>; Option<Tag::kIsPersistent, true_type>>;
run_fmha_fwd<Element, ElementOut, IsVarlen, IsMla, Mask, Option>( run_fmha_fwd<Element, ElementOut, IsVarlen, IsMla, Mask, Option>(
......
...@@ -143,8 +143,8 @@ struct FwdRunner { ...@@ -143,8 +143,8 @@ struct FwdRunner {
ProblemShapeType problem_size_for_launch; ProblemShapeType problem_size_for_launch;
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q}; get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv}; get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
get<2>(problem_size_for_launch) = get<2>(problem_size); get<2>(problem_size_for_launch) = get<2>(problem_size);
get<3>(problem_size_for_launch) = get<3>(problem_size); get<3>(problem_size_for_launch) = get<3>(problem_size);
...@@ -206,10 +206,6 @@ struct FwdRunner { ...@@ -206,10 +206,6 @@ struct FwdRunner {
void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr, void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr,
void *cumulative_length_q, void *cumulative_length_kv) { void *cumulative_length_q, void *cumulative_length_kv) {
auto problem_shape_ = problem_shape; auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = static_cast<int *>(cumulative_length_q);
get<1>(problem_shape_).cumulative_length = static_cast<int *>(cumulative_length_kv);
}
typename Operation::Arguments arguments{ typename Operation::Arguments arguments{
problem_shape_, problem_shape_,
...@@ -230,6 +226,7 @@ struct FwdRunner { ...@@ -230,6 +226,7 @@ struct FwdRunner {
int total_seqlen_q = q.size(0); int total_seqlen_q = q.size(0);
int total_seqlen_kv = k.size(0); int total_seqlen_kv = k.size(0);
ProblemShapeType problem_shape = ProblemShapeType problem_shape =
initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv, initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv,
cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr()); cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());
...@@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v ...@@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v
auto options = get_options(); auto options = get_options();
if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 && if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 &&
(!std::is_same_v<ActiveMask, NoMask>)) { (std::is_same_v<ActiveMask, CausalMask<false>> || std::is_same_v<ActiveMask, CausalMask<true>>)) {
FwdRunner<kIsMla, true, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner; FwdRunner<kIsMla, true, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;
runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q, runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,
cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv); cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);
......
...@@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Correction) { else if (role == WarpRole::Correction) {
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>(); cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
...@@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue; continue;
} }
has_valid = true;
if (get<1>(logical_problem_shape) == 0) { if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty( mainloop.correction_empty(
blk_coord, blk_coord,
...@@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
if constexpr (NumWarpsEpilogue == 0) { if constexpr (NumWarpsEpilogue == 0) {
static_assert(NumWarpsCorrection == 1); static_assert(NumWarpsCorrection == 1);
if (has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} }
}
} }
else if (role == WarpRole::MMA) { else if (role == WarpRole::MMA) {
warpgroup_reg_set<NumRegsOther>(); warpgroup_reg_set<NumRegsOther>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); bool allocated = false;
__syncwarp();
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
...@@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue; continue;
} }
if (!allocated) {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
allocated = true;
}
if (get<1>(logical_problem_shape) == 0) { if (get<1>(logical_problem_shape) == 0) {
continue; continue;
} }
...@@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Epilogue) { else if (role == WarpRole::Epilogue) {
warpgroup_reg_set<NumRegsOther>(); warpgroup_reg_set<NumRegsOther>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
...@@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue; continue;
} }
has_valid = true;
epilogue.store( epilogue.store(
blk_coord, logical_problem_shape, blk_coord, logical_problem_shape,
params.epilogue, params.problem_shape, params.epilogue, params.problem_shape,
...@@ -602,9 +617,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { ...@@ -602,9 +617,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static_assert(NumWarpsEpilogue <= 1); static_assert(NumWarpsEpilogue <= 1);
if constexpr (NumWarpsEpilogue == 1) { if constexpr (NumWarpsEpilogue == 1) {
if(has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} }
}
} }
else if (role == WarpRole::Empty) { else if (role == WarpRole::Empty) {
......
...@@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win ...@@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
grad_out = torch.randn(total_q, h, dv) grad_out = torch.randn(total_q, h, dv)
softmax_scale = (d + 100) ** (-0.5) softmax_scale = (d + 100) ** (-0.5)
offst_q = total_q q1 = q.clone().requires_grad_()
offst_kv = total_k k1 = k.clone().requires_grad_()
v1 = v.clone().requires_grad_()
q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype)
k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype)
v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype)
q1_with_buffer[total_q:] = q
k1_with_buffer[offst_kv:] = k
v1_with_buffer[offst_kv:] = v
q1 = q1_with_buffer[offst_q:].requires_grad_()
k1 = k1_with_buffer[offst_kv:].requires_grad_()
v1 = v1_with_buffer[offst_kv:].requires_grad_()
q2 = q.clone().requires_grad_() q2 = q.clone().requires_grad_()
k2 = k.clone().requires_grad_() k2 = k.clone().requires_grad_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment