Commit 0739bc5a authored by Po Yen Chen
Browse files

Use kv_perm to control key/value layout

parent 7c0e5822
...@@ -580,12 +580,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -580,12 +580,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back() : (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back())); : seqstart_k_with_padding_host.back()));
bool kv_perm = (mode == mode_enum::group && 0 < page_block_size ? true : i_perm);
ck_tile::HostTensor<QDataType> q_host( ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host( ck_tile::HostTensor<KDataType> k_host(
0 < page_block_size 0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) ? get_lengths(kv_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); : get_lengths(kv_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host( ck_tile::HostTensor<KDataType> knew_host(
0 < seqlen_knew 0 < seqlen_knew
...@@ -594,10 +596,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -594,10 +596,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<VDataType> v_host( ck_tile::HostTensor<VDataType> v_host(
0 < page_block_size 0 < page_block_size
? (is_v_rowmajor ? (is_v_rowmajor
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) ? get_lengths(kv_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) : get_lengths(kv_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : (is_v_rowmajor ? get_lengths(kv_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); : get_lengths(kv_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
ck_tile::HostTensor<VDataType> vnew_host( ck_tile::HostTensor<VDataType> vnew_host(
0 < seqlen_knew 0 < seqlen_knew
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
...@@ -762,9 +764,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -762,9 +764,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(mode == mode_enum::group && 0 < page_block_size) if(mode == mode_enum::group && 0 < page_block_size)
{ {
if(!(i_perm && !is_v_rowmajor)) if(!is_v_rowmajor)
{ {
std::cerr << "make sure input layout is correct" << std::endl; std::cerr << "make sure input layout is correct: -vlayout=r" << std::endl;
return false; return false;
} }
...@@ -877,14 +879,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -877,14 +879,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
/// 'nhead_stride_bias' are 0. /// 'nhead_stride_bias' are 0.
// setup stride_* arguments // setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); const ck_tile::index_t stride_k = (kv_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() { const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v; return kv_perm ? hdim_v : nhead_k * hdim_v;
else else
return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) return 0 < page_block_size ? (kv_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); : (kv_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}(); }();
const ck_tile::index_t stride_vnew = [&]() { const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
...@@ -899,16 +901,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -899,16 +901,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = const ck_tile::index_t nhead_stride_k =
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) (0 < page_block_size ? (kv_perm ? page_block_size * hdim_q : hdim_q)
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); : (kv_perm ? shape_seqlen_k * hdim_q : hdim_q));
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() { const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) return 0 < page_block_size ? (kv_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v); : (kv_perm ? shape_seqlen_k * hdim_v : hdim_v);
else else
return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) return 0 < page_block_size ? (kv_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); : (kv_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
}(); }();
const ck_tile::index_t nhead_stride_vnew = [&]() { const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
...@@ -1278,7 +1280,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1278,7 +1280,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
#endif #endif
#if CK_TILE_FMHA_FWD_SPLITKV_API #if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size) { if(0 < page_block_size) {
if(i_perm) { if(kv_perm) {
k_host_ref.ForEach([&](auto& self, auto i) { k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
}); });
...@@ -1290,7 +1292,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1290,7 +1292,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} else } else
#endif #endif
{ {
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); if(kv_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
} }
...@@ -1330,7 +1332,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1330,7 +1332,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
#if CK_TILE_FMHA_FWD_SPLITKV_API #if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size) { if(0 < page_block_size) {
if(is_v_rowmajor) { if(is_v_rowmajor) {
if(i_perm) { if(kv_perm) {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
}); });
...@@ -1342,7 +1344,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1342,7 +1344,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
else else
{ {
if(i_perm) { if(kv_perm) {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
}); });
...@@ -1357,13 +1359,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1357,13 +1359,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
if(is_v_rowmajor) { if(is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); if(kv_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
} }
else else
{ {
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); if(kv_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment