Unverified commit 0ab3f437, authored by Trevor Morris, committed by GitHub
Browse files

Cutlass MLA: Disable split kv due to https://github.com/NVIDIA/cutlass/issues/2274 (#6101)

parent cec98f10
...@@ -151,7 +151,10 @@ typename T::Fmha::Arguments args_from_options( ...@@ -151,7 +151,10 @@ typename T::Fmha::Arguments args_from_options(
page_size}, page_size},
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE}, {static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info, hw_info,
-1, // split_kv // TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
1, // split_kv
nullptr, // is_var_split_kv nullptr, // is_var_split_kv
}; };
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
......
...@@ -67,7 +67,7 @@ def test_cutlass_mla_decode( ...@@ -67,7 +67,7 @@ def test_cutlass_mla_decode(
pack_factor = 128 // block_size pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
q = torch.randn(bs, h_q, d) q = torch.randn(bs, h_q, d) * 100.0
block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32) block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
kv_cache = torch.randn(block_table.numel(), block_size, d) kv_cache = torch.randn(block_table.numel(), block_size, d)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment