Commit 3efdb593 authored by danyao12
Browse files

Use unpadded LSE and D tensors for group mode

parent 25db1339
...@@ -287,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -287,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host( ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host( ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}); std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host( ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}); std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host( ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1}); : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
...@@ -441,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -441,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q; const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias = const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k); (i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments // setup batch_stride_* arguments
...@@ -452,7 +452,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -452,7 +452,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q); const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v); const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
...@@ -749,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -749,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); }); if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); }); else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); }); lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on // clang-format on
q_host_refs.push_back(q_host_ref); q_host_refs.push_back(q_host_ref);
......
...@@ -187,7 +187,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) ...@@ -187,7 +187,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_dk, args.nhead_stride_dk,
args.nhead_stride_dv, args.nhead_stride_dv,
args.nhead_stride_dbias, args.nhead_stride_dbias,
args.batch_stride_lsed,
args.split_stride_dq_acc, args.split_stride_dq_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
...@@ -278,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) ...@@ -278,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o, args.stride_o,
args.nhead_stride_do, args.nhead_stride_do,
args.nhead_stride_o, args.nhead_stride_o,
args.nhead_stride_lsed, args.nhead_stride_lsed);
args.batch_stride_lsed);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
......
...@@ -155,8 +155,6 @@ struct FmhaBwdDQDKDVKernel ...@@ -155,8 +155,6 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_dq_acc; ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dk; ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv; ck_tile::index_t nhead_stride_dv;
ck_tile::index_t batch_stride_lsed;
}; };
struct FmhaBwdCommonBiasKargs struct FmhaBwdCommonBiasKargs
...@@ -246,6 +244,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -246,6 +244,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc; ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk; ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv; ck_tile::index_t batch_stride_dv;
...@@ -361,17 +360,17 @@ struct FmhaBwdDQDKDVKernel ...@@ -361,17 +360,17 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_dq_acc, nhead_stride_dq_acc,
nhead_stride_dk, nhead_stride_dk,
nhead_stride_dv, nhead_stride_dv}, // args for common karg
batch_stride_lsed}, // args for common karg {}, // placeholder for bias
{}, // placeholder for bias {}, // placeholder for dbias
{}, // placeholder for dbias {}, // placeholder for mask
{}, // placeholder for mask {}, // placeholder for dropout
{}, // placeholder for dropout {}, // placeholder for deterministic
{}, // placeholder for deterministic
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v, batch_stride_v,
batch_stride_do, batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc, batch_stride_dq_acc,
batch_stride_dk, batch_stride_dk,
batch_stride_dv}; batch_stride_dv};
...@@ -467,7 +466,6 @@ struct FmhaBwdDQDKDVKernel ...@@ -467,7 +466,6 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_dk, ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv, ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias, ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc, ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
...@@ -506,13 +504,12 @@ struct FmhaBwdDQDKDVKernel ...@@ -506,13 +504,12 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_lsed, nhead_stride_lsed,
nhead_stride_dq_acc, nhead_stride_dq_acc,
nhead_stride_dk, nhead_stride_dk,
nhead_stride_dv, nhead_stride_dv}, // args for common karg
batch_stride_lsed}, // args for common karg {}, // placeholder for bias
{}, // placeholder for bias {}, // placeholder for dbias
{}, // placeholder for dbias {}, // placeholder for mask
{}, // placeholder for mask {}, // placeholder for dropout
{}, // placeholder for dropout {}, // placeholder for deterministic
{}, // placeholder for deterministic
reinterpret_cast<const int32_t*>(seqstart_q_ptr), reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr), reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)}; reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
...@@ -615,7 +612,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -615,7 +612,7 @@ struct FmhaBwdDQDKDVKernel
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do; batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed; batch_offset_lsed = query_start;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc; batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
batch_offset_dk = key_start * kargs.stride_dk; batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv; batch_offset_dv = key_start * kargs.stride_dv;
...@@ -1142,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel ...@@ -1142,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_d; ck_tile::index_t nhead_stride_d;
ck_tile::index_t batch_stride_d;
}; };
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{ {
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_d;
}; };
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
...@@ -1186,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel ...@@ -1186,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o, stride_o,
nhead_stride_do, nhead_stride_do,
nhead_stride_o, nhead_stride_o,
nhead_stride_d, nhead_stride_d},
batch_stride_d},
batch_stride_do, batch_stride_do,
batch_stride_o}; batch_stride_o,
batch_stride_d};
return kargs; return kargs;
} }
...@@ -1206,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1206,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t stride_o, ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d, ck_tile::index_t nhead_stride_d)
ck_tile::index_t batch_stride_d)
{ {
Kargs kargs{{o_ptr, Kargs kargs{{o_ptr,
do_ptr, do_ptr,
...@@ -1219,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1219,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o, stride_o,
nhead_stride_do, nhead_stride_do,
nhead_stride_o, nhead_stride_o,
nhead_stride_d, nhead_stride_d},
batch_stride_d},
reinterpret_cast<const int32_t*>(seqstart_q_ptr)}; reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
return kargs; return kargs;
...@@ -1263,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1263,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o = query_start * kargs.stride_o; batch_offset_o = query_start * kargs.stride_o;
batch_offset_do = query_start * kargs.stride_do; batch_offset_do = query_start * kargs.stride_do;
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d; batch_offset_d = query_start;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment