Commit 86517ce4 authored by Po Yen Chen
Browse files

Use vllm paged-kcache layout to read blocks

parent 44828b7c
...@@ -635,12 +635,38 @@ struct FmhaFwdSplitKVKernel ...@@ -635,12 +635,38 @@ struct FmhaFwdSplitKVKernel
}(); }();
const auto make_k_dram = [&](const KDataType* data, index_t height) { const auto make_k_dram = [&](const KDataType* data, index_t height) {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>( auto k_dram_naive = [&] {
data, // will update this pointer if using paged-kvcache if constexpr(kIsPagedKV)
make_tuple(height, kargs.hdim_q), {
make_tuple(kargs.stride_k, 1), constexpr index_t vector_size = FmhaPipeline::kAlignmentK;
number<FmhaPipeline::kAlignmentK>{}, // (hdim_q/vector_size, seqlen_k, vector_size)
number<1>{}); const auto view = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_q / vector_size, height, number<vector_size>{}),
make_tuple(height * vector_size, number<vector_size>{}, number<1>{}),
number<vector_size>{},
number<1>{});
// (seqlen_k, hdim_q)
return transform_tensor_view(
view,
make_tuple(make_pass_through_transform(height),
make_merge_transform(make_tuple(kargs.hdim_q / vector_size,
number<vector_size>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
else
{
// (seqlen_k, hdim_q)
return make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
}
}();
return pad_tensor_view( return pad_tensor_view(
k_dram_naive, k_dram_naive,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment