Commit ead9c3cb authored by Po Yen Chen
Browse files

Fix wrong vector size & stride

parent 1f12c4e0
...@@ -638,12 +638,14 @@ struct FmhaFwdSplitKVKernel ...@@ -638,12 +638,14 @@ struct FmhaFwdSplitKVKernel
auto k_dram_naive = [&] { auto k_dram_naive = [&] {
if constexpr(kIsPagedKV) if constexpr(kIsPagedKV)
{ {
constexpr index_t vector_size = FmhaPipeline::kAlignmentK; constexpr index_t vector_size = 16 / sizeof(KDataType);
// (hdim_q/vector_size, page_block_size, vector_size) // (hdim_q/vector_size, page_block_size, vector_size)
const auto view = make_naive_tensor_view<address_space_enum::global>( const auto view = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_q / vector_size, height, number<vector_size>{}), make_tuple(kargs.hdim_q / vector_size, height, number<vector_size>{}),
make_tuple(height * vector_size, number<vector_size>{}, number<1>{}), make_tuple(kargs.page_block_size * vector_size,
number<vector_size>{},
number<1>{}),
number<vector_size>{}, number<vector_size>{},
number<1>{}); number<1>{});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment