"test/srt/vscode:/vscode.git/clone" did not exist on "24eaebeb4b43ca24c8bf9aaf8c9d0836487f07df"
Commit 86517ce4 authored by Po Yen Chen
Browse files

Use vllm paged-kcache layout to read blocks

parent 44828b7c
......@@ -635,12 +635,38 @@ struct FmhaFwdSplitKVKernel
}();
const auto make_k_dram = [&](const KDataType* data, index_t height) {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
// Build the un-padded K tensor view over `data`. The branch is chosen at
// compile time via `if constexpr(kIsPagedKV)`; both branches are commented as
// producing a 2-D (seqlen_k, hdim_q) view, with `height` as the seqlen_k length.
auto k_dram_naive = [&] {
if constexpr(kIsPagedKV)
{
// Length of the innermost, unit-stride dimension; taken from the pipeline's K alignment.
constexpr index_t vector_size = FmhaPipeline::kAlignmentK;
// (hdim_q/vector_size, seqlen_k, vector_size)
// Strides are (height * vector_size, vector_size, 1): each group of
// `vector_size` elements is contiguous, consecutive rows are `vector_size` apart,
// and consecutive hdim chunks are `height * vector_size` apart.
// NOTE(review): kargs.stride_k is not used in this branch, and the integer
// division presumably expects hdim_q to be a multiple of vector_size — confirm.
const auto view = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_q / vector_size, height, number<vector_size>{}),
make_tuple(height * vector_size, number<vector_size>{}, number<1>{}),
number<vector_size>{},
number<1>{});
// (seqlen_k, hdim_q)
// Re-express the 3-D view as 2-D: source dim 1 (height) is passed through to
// result dim 0, and source dims 0 and 2 (hdim chunk index, offset within the
// chunk) are merged into result dim 1.
return transform_tensor_view(
view,
make_tuple(make_pass_through_transform(height),
make_merge_transform(make_tuple(kargs.hdim_q / vector_size,
number<vector_size>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
// (seqlen_k, hdim_q)
// Non-paged case: plain 2-D view with lengths (height, kargs.hdim_q) and
// strides (kargs.stride_k, 1).
return make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
}
}();
return pad_tensor_view(
k_dram_naive,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment