Unverified Commit 2d291b0c authored by zhang's avatar zhang Committed by GitHub
Browse files

Remove tma padding for fwd inputs (#85)

parent c7590278
......@@ -225,8 +225,8 @@ struct CausalMask : NoMask {
if constexpr (IsQBegin) {
return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
} else {
const int offset_tile_q = get<1>(problem_size) % get<1>(tile_shape);
return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape))));
const int corner_count = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
return std::min(trip_count, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_count);
}
}
......
......@@ -95,32 +95,21 @@ struct Sm100FmhaLoadTmaWarpspecialized {
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
auto problem_shape_qk = problem_shape;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dQ) = get<0>(dQ);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = problem_shape;
}
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
......@@ -181,19 +170,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
......@@ -208,19 +194,16 @@ struct Sm100FmhaLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
......@@ -235,7 +218,7 @@ struct Sm100FmhaLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
......
......@@ -102,32 +102,21 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dQ) = get<0>(dQ);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_q * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_Q -= max_length_q * get<0>(dQ);
}
}
if constexpr (is_variable_length_v<tuple_element_t<1, ProblemShape>>) {
auto cumulative_length_kv = get<1>(problem_shape).cumulative_length;
if (cumulative_length_kv != nullptr) {
int max_length_kv = get<1>(problem_shape).max_length;
// for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dK) = get<0>(dK);
get<2,1>(dV) = get<0>(dV);
get<3,1>(problem_shape_qk) = std::max(get<3,1>(problem_shape_qk), max_length_kv * (1 + get<3,1>(problem_shape)));
// offset ptr by the amount we add back in later
ptr_K -= max_length_kv * get<0>(dK);
ptr_V -= max_length_kv * get<0>(dV);
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
}
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
......@@ -192,19 +181,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0;
int q_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
q_offs_0 = max_length_q - get<0>(problem_shape);
q_offs_2_1 = cumulative_length_q[get<2,1>(blk_coord_q)] + get<0>(problem_shape);
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
......@@ -219,19 +205,16 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0;
int kv_offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
int max_length = get<1>(params_problem_shape).max_length;
kv_offs_0 = max_length - get<1>(problem_shape);
kv_offs_2_1 = cumulative_length[get<2,1>(blk_coord_kv)] + get<1>(problem_shape);
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, kv_offs_2_1)), mK_kdl_p);
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
......@@ -246,7 +229,7 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, kv_offs_2_1)), mV_dkl_p);
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
......
......@@ -18,7 +18,8 @@ void call_run_fmha_fwd([[maybe_unused]] Mask mask, [[maybe_unused]] Varlen is_va
static constexpr bool IsVarlen = std::is_same_v<Varlen, true_type>;
static constexpr bool IsMla = std::is_same_v<Mla, true_type>;
static constexpr bool IsCausalMask = std::is_same_v<Mask, CausalMask<false>>;
using Option = std::conditional_t<IsCausalMask, Option<Tag::kIsPersistent, false_type>,
using Option =
std::conditional_t<IsCausalMask || (IsVarlen), Option<Tag::kIsPersistent, false_type>,
Option<Tag::kIsPersistent, true_type>>;
run_fmha_fwd<Element, ElementOut, IsVarlen, IsMla, Mask, Option>(
......
......@@ -143,8 +143,8 @@ struct FwdRunner {
ProblemShapeType problem_size_for_launch;
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv};
get<0>(problem_size_for_launch) = VariableLength{max_seqlen_q, nullptr, total_seqlen_q};
get<1>(problem_size_for_launch) = VariableLength{max_seqlen_kv, nullptr, total_seqlen_kv};
get<2>(problem_size_for_launch) = get<2>(problem_size);
get<3>(problem_size_for_launch) = get<3>(problem_size);
......@@ -206,10 +206,6 @@ struct FwdRunner {
void *q_ptr, void *k_ptr, void *v_ptr, void *o_ptr, void *lse_ptr,
void *cumulative_length_q, void *cumulative_length_kv) {
auto problem_shape_ = problem_shape;
if constexpr (kIsVarlen) {
get<0>(problem_shape_).cumulative_length = static_cast<int *>(cumulative_length_q);
get<1>(problem_shape_).cumulative_length = static_cast<int *>(cumulative_length_kv);
}
typename Operation::Arguments arguments{
problem_shape_,
......@@ -230,6 +226,7 @@ struct FwdRunner {
int total_seqlen_q = q.size(0);
int total_seqlen_kv = k.size(0);
ProblemShapeType problem_shape =
initialize(options, max_seqlen_q, max_seqlen_kv, total_seqlen_q, total_seqlen_kv,
cumulative_seqlen_q.data_ptr(), cumulative_seqlen_kv.data_ptr());
......@@ -322,7 +319,7 @@ void run_fmha_fwd(at::Tensor workspace, at::Tensor q, at::Tensor k, at::Tensor v
auto options = get_options();
if (options.h % cutlass::fmha::kernel::CausalIndividualTileScheduler::TileH == 0 &&
(!std::is_same_v<ActiveMask, NoMask>)) {
(std::is_same_v<ActiveMask, CausalMask<false>> || std::is_same_v<ActiveMask, CausalMask<true>>)) {
FwdRunner<kIsMla, true, kIsVarlen, DTypeIn, DTypeOut, ActiveMask, KernelOptions...> runner;
runner.run(options, hw_info, q, k, v, o, lse, scale_softmax, workspace, cumulative_seqlen_q,
cumulative_seqlen_kv, max_seqlen_q, max_seqlen_kv);
......
......@@ -465,6 +465,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Correction) {
cutlass::arch::warpgroup_reg_dealloc<NumRegsCorrection>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
......@@ -476,6 +478,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
has_valid = true;
if (get<1>(logical_problem_shape) == 0) {
mainloop.correction_empty(
blk_coord,
......@@ -505,16 +509,17 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
if constexpr (NumWarpsEpilogue == 0) {
static_assert(NumWarpsCorrection == 1);
if (has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
}
else if (role == WarpRole::MMA) {
warpgroup_reg_set<NumRegsOther>();
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
bool allocated = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
......@@ -527,6 +532,12 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
if (!allocated) {
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
allocated = true;
}
if (get<1>(logical_problem_shape) == 0) {
continue;
}
......@@ -580,6 +591,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
else if (role == WarpRole::Epilogue) {
warpgroup_reg_set<NumRegsOther>();
bool has_valid = false;
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
......@@ -591,6 +604,8 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
continue;
}
has_valid = true;
epilogue.store(
blk_coord, logical_problem_shape,
params.epilogue, params.problem_shape,
......@@ -602,9 +617,11 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
static_assert(NumWarpsEpilogue <= 1);
if constexpr (NumWarpsEpilogue == 1) {
if(has_valid) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
}
else if (role == WarpRole::Empty) {
......
......@@ -82,18 +82,9 @@ def test_flash_attention(b, mean_sq, mean_sk, varlen, h, h_k, d, dv, causal, win
grad_out = torch.randn(total_q, h, dv)
softmax_scale = (d + 100) ** (-0.5)
offst_q = total_q
offst_kv = total_k
q1_with_buffer = torch.empty(total_q + total_q, h, d, device=device, dtype=dtype)
k1_with_buffer = torch.empty(offst_kv + total_k, h_k, d, device=device, dtype=dtype)
v1_with_buffer = torch.empty(offst_kv + total_k, h_k, dv, device=device, dtype=dtype)
q1_with_buffer[total_q:] = q
k1_with_buffer[offst_kv:] = k
v1_with_buffer[offst_kv:] = v
q1 = q1_with_buffer[offst_q:].requires_grad_()
k1 = k1_with_buffer[offst_kv:].requires_grad_()
v1 = v1_with_buffer[offst_kv:].requires_grad_()
q1 = q.clone().requires_grad_()
k1 = k.clone().requires_grad_()
v1 = v.clone().requires_grad_()
q2 = q.clone().requires_grad_()
k2 = k.clone().requires_grad_()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment