Commit 8241459e authored by Ye Wang's avatar Ye Wang
Browse files

change the layout to sbhd for debugging

parent 33ff338d
...@@ -78,7 +78,7 @@ auto create_args(int argc, char* argv[]) ...@@ -78,7 +78,7 @@ auto create_args(int argc, char* argv[])
.insert("iperm", .insert("iperm",
"1", "1",
"permute input\n" "permute input\n"
"if true, will be b*h*s*d, else b*s*h*d") "if true, will be s*b*h*d, else b*s*h*d")
.insert("operm", "1", "permute output") .insert("operm", "1", "permute output")
.insert("bias", .insert("bias",
"n", "n",
...@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -295,7 +295,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(hdim_v < 0) if(hdim_v < 0)
hdim_v = hdim_q; hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim bool i_perm = arg_parser.get_bool("iperm"); // if true, will be seqlen * batch * nhead * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
float scale_s = arg_parser.get_float("scale_s"); float scale_s = arg_parser.get_float("scale_s");
...@@ -443,7 +443,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -443,7 +443,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t s /*seqlen*/, ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/) { ck_tile::index_t d /*hdim*/) {
if(permute) if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d}; return std::array<ck_tile::index_t, 4>{s, b, h, d};
else else
return std::array<ck_tile::index_t, 4>{b, s, h, d}; return std::array<ck_tile::index_t, 4>{b, s, h, d};
}; };
...@@ -592,7 +592,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -592,7 +592,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format off // clang-format off
auto layout_str = [&](bool permute){ auto layout_str = [&](bool permute){
if (permute) return std::string("bhsd"); if (permute) return std::string("sbhd");
else return std::string("bshd"); else return std::string("bshd");
}; };
auto io_layout = [&](bool iperm_, bool operm_) { auto io_layout = [&](bool iperm_, bool operm_) {
...@@ -647,24 +647,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -647,24 +647,25 @@ bool run(const ck_tile::ArgParser& arg_parser)
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0. /// 'nhead_stride_bias' are 0.
// setup stride_* arguments // setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); // bshd vs sbhd (perm)
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); const ck_tile::index_t stride_q = (i_perm ? batch * nhead * hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? batch * nhead_k* hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() { const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v; return i_perm ? batch * nhead_k * hdim_v : nhead_k * hdim_v;
else else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v; const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? batch * nhead * hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = hdim_q;
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k = hdim_q;
const ck_tile::index_t nhead_stride_v = [&]() { const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v; return hdim_v;
else else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}(); }();
...@@ -674,17 +675,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -674,17 +675,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = hdim_v;
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = i_perm? nhead * hdim_q : nhead * shape_seqlen_q * hdim_q;
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_k = i_perm? nhead_k * hdim_q : nhead_k * shape_seqlen_k * hdim_q;
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); const ck_tile::index_t batch_stride_v = i_perm? nhead_k * hdim_v : nhead_k * hdim_v * shape_seqlen_k;
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = o_perm? nhead * hdim_v : nhead * shape_seqlen_q * hdim_v;
// setup split_stride_* arguments (only used in split-kv kernel) // setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
...@@ -814,16 +815,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -814,16 +815,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t nr = nhead / nhead_k; ck_tile::index_t nr = nhead / nhead_k;
// clang-format off // clang-format off
// permute // permute bshd vs sbhd (perm)
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(i[1] + query_offset, b, i[0], i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(i[1] + key_offset, b, i[0] / nr, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
if (is_v_rowmajor) { if (is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(i[2] + key_offset, b, i[0] / nr, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
} }
...@@ -976,7 +977,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -976,7 +977,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v}); ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off // clang-format off
// permute // permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(idx[1] + query_offset, b, idx[0], idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on // clang-format on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment