Commit 8241459e authored by Ye Wang's avatar Ye Wang
Browse files

change the layout to sbhd for debugging

parent 33ff338d
......@@ -78,7 +78,7 @@ auto create_args(int argc, char* argv[])
.insert("iperm",
"1",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
"if true, will be s*b*h*d, else b*s*h*d")
.insert("operm", "1", "permute output")
.insert("bias",
"n",
......@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(hdim_v < 0)
hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be seqlen * batch * nhead * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
float scale_s = arg_parser.get_float("scale_s");
......@@ -443,7 +443,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/) {
if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d};
return std::array<ck_tile::index_t, 4>{s, b, h, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
......@@ -592,7 +592,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format off
auto layout_str = [&](bool permute){
if (permute) return std::string("bhsd");
if (permute) return std::string("sbhd");
else return std::string("bshd");
};
auto io_layout = [&](bool iperm_, bool operm_) {
......@@ -647,24 +647,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
// bshd vs sbhd (perm)
const ck_tile::index_t stride_q = (i_perm ? batch * nhead * hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? batch * nhead_k* hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
return i_perm ? batch * nhead_k * hdim_v : nhead_k * hdim_v;
else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_o = (o_perm ? batch * nhead * hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_q = hdim_q;
const ck_tile::index_t nhead_stride_k = hdim_q;
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
return hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
......@@ -674,17 +675,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_o = hdim_v;
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_q = i_perm? nhead * hdim_q : nhead * shape_seqlen_q * hdim_q;
const ck_tile::index_t batch_stride_k = i_perm? nhead_k * hdim_q : nhead_k * shape_seqlen_k * hdim_q;
const ck_tile::index_t batch_stride_v = i_perm? nhead_k * hdim_v : nhead_k * hdim_v * shape_seqlen_k;
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = o_perm? nhead * hdim_v : nhead * shape_seqlen_q * hdim_v;
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
......@@ -814,16 +815,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nr = nhead / nhead_k;
// clang-format off
// permute
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
// permute bshd vs sbhd (perm)
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(i[1] + query_offset, b, i[0], i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(i[1] + key_offset, b, i[0] / nr, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
if (is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(i[2] + key_offset, b, i[0] / nr, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
}
......@@ -976,7 +977,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off
// permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(idx[1] + query_offset, b, idx[0], idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment