Unverified commit 10cc12ba, authored by Matthew Bonanni, committed by GitHub
Browse files

Feature/mla tests (#23195)


Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent a4fbb32f
...@@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache( ...@@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache(
# Permute the context blocks (excluding block 0 which is null) # Permute the context blocks (excluding block 0 which is null)
if randomize_blocks: if randomize_blocks:
perm = torch.randperm( # Random permutation starting from block 1
blocks_end - 1) + 1 # Random permutation starting from block 1 perm = torch.randperm(blocks_end - 1) + 1
else: else:
perm = torch.arange( # Sequential order starting from block 1
1, blocks_end) # Sequential order starting from block 1 perm = torch.arange(1, blocks_end)
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
inv_perm[1:] = torch.argsort( # Add 1 to account for starting from block 1
perm) + 1 # Add 1 to account for starting from block 1 inv_perm[1:] = torch.argsort(perm) + 1
kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
# Construct the right block table # Construct the right block table
...@@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, ...@@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
@pytest.mark.parametrize("batch_spec_name", [ @pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode", "small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium" "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
]) ])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_backend_correctness(batch_spec_name: str, model: str): def test_backend_correctness(batch_spec_name: str, model: str):
...@@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): ...@@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
""" """
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
vllm_config = create_vllm_config(model_name=model, vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens)) max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=8192)
device = torch.device("cuda:0") device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config) kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
...@@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): ...@@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
if not all_close:
print(f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
print(f"[{backend_name}] output: {backend_output}")
print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
assert all_close, ( assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. " f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
\ No newline at end of file
This diff is collapsed.
...@@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend): ...@@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS_VLLM_V1: _Backend.XFORMERS_VLLM_V1:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend", "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
_Backend.CUTLASS_MLA:
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
} }
if backend_name not in backend_map: if backend_name not in backend_map:
...@@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ...@@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
max_model_len: int = 1024, max_model_len: int = 1024,
dtype: Union[ModelDType, torch.dtype] = "auto", dtype: Union[ModelDType, torch.dtype] = "auto",
num_gpu_blocks: int = 1000,
block_size: int = 16, block_size: int = 16,
max_num_seqs: int = 256, max_num_seqs: int = 256,
max_num_batched_tokens: int = 8192, max_num_batched_tokens: int = 8192,
enable_chunked_prefill: bool = True,
add_mock_model_methods: bool = True) -> VllmConfig: add_mock_model_methods: bool = True) -> VllmConfig:
"""Create a VllmConfig for testing with reasonable defaults.""" """Create a VllmConfig for testing with reasonable defaults."""
...@@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ...@@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
) )
# Set cache blocks for testing # Set cache blocks for testing
# (these may be set during initialization normally) # (these may be set during initialization normally)
cache_config.num_gpu_blocks = 1000 cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0 cache_config.num_cpu_blocks = 0
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
...@@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ...@@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
) )
device_config = DeviceConfig() device_config = DeviceConfig()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment