Commit 3efdb593 authored by danyao12's avatar danyao12
Browse files

unpadded lse&d for group mode

parent 25db1339
......@@ -287,9 +287,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q});
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
......@@ -441,7 +441,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = max_seqlen_q;
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
......@@ -452,7 +452,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * max_seqlen_q);
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
......@@ -749,7 +749,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(wb, idx[0], idx[1]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on
q_host_refs.push_back(q_host_ref);
......
......@@ -187,7 +187,6 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
......@@ -278,8 +277,7 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
......
......@@ -155,8 +155,6 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t batch_stride_lsed;
};
struct FmhaBwdCommonBiasKargs
......@@ -246,6 +244,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
......@@ -361,17 +360,17 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
nhead_stride_dv}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_do,
batch_stride_lsed,
batch_stride_dq_acc,
batch_stride_dk,
batch_stride_dv};
......@@ -467,7 +466,6 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t nhead_stride_dk,
ck_tile::index_t nhead_stride_dv,
ck_tile::index_t nhead_stride_dbias,
ck_tile::index_t batch_stride_lsed,
ck_tile::index_t split_stride_dq_acc,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
......@@ -506,13 +504,12 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_lsed,
nhead_stride_dq_acc,
nhead_stride_dk,
nhead_stride_dv,
batch_stride_lsed}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
nhead_stride_dv}, // args for common karg
{}, // placeholder for bias
{}, // placeholder for dbias
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
......@@ -615,7 +612,7 @@ struct FmhaBwdDQDKDVKernel
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
batch_offset_lsed = query_start;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
......@@ -1142,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_d;
ck_tile::index_t batch_stride_d;
};
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_d;
};
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
......@@ -1186,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o,
nhead_stride_do,
nhead_stride_o,
nhead_stride_d,
batch_stride_d},
nhead_stride_d},
batch_stride_do,
batch_stride_o};
batch_stride_o,
batch_stride_d};
return kargs;
}
......@@ -1206,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d,
ck_tile::index_t batch_stride_d)
ck_tile::index_t nhead_stride_d)
{
Kargs kargs{{o_ptr,
do_ptr,
......@@ -1219,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o,
nhead_stride_do,
nhead_stride_o,
nhead_stride_d,
batch_stride_d},
nhead_stride_d},
reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
return kargs;
......@@ -1263,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o = query_start * kargs.stride_o;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
batch_offset_d = query_start;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment