Unverified Commit 10cc12ba authored by Matthew Bonanni, committed by GitHub
Browse files

Feature/mla tests (#23195)


Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent a4fbb32f
...@@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache( ...@@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache(
# Permute the context blocks (excluding block 0 which is null) # Permute the context blocks (excluding block 0 which is null)
if randomize_blocks: if randomize_blocks:
perm = torch.randperm( # Random permutation starting from block 1
blocks_end - 1) + 1 # Random permutation starting from block 1 perm = torch.randperm(blocks_end - 1) + 1
else: else:
perm = torch.arange( # Sequential order starting from block 1
1, blocks_end) # Sequential order starting from block 1 perm = torch.arange(1, blocks_end)
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
inv_perm[1:] = torch.argsort( # Add 1 to account for starting from block 1
perm) + 1 # Add 1 to account for starting from block 1 inv_perm[1:] = torch.argsort(perm) + 1
kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
# Construct the right block table # Construct the right block table
...@@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, ...@@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
@pytest.mark.parametrize("batch_spec_name", [ @pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode", "small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium" "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
]) ])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_backend_correctness(batch_spec_name: str, model: str): def test_backend_correctness(batch_spec_name: str, model: str):
...@@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): ...@@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
""" """
batch_spec = BATCH_SPECS[batch_spec_name] batch_spec = BATCH_SPECS[batch_spec_name]
vllm_config = create_vllm_config(model_name=model, vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens)) max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=8192)
device = torch.device("cuda:0") device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config) kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
...@@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): ...@@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
if not all_close:
print(f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
print(f"[{backend_name}] output: {backend_output}")
print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
assert all_close, ( assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. " f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
\ No newline at end of file
This diff is collapsed.
...@@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend): ...@@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS_VLLM_V1: _Backend.XFORMERS_VLLM_V1:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend", "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
_Backend.CUTLASS_MLA:
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
} }
if backend_name not in backend_map: if backend_name not in backend_map:
...@@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ...@@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
max_model_len: int = 1024, max_model_len: int = 1024,
dtype: Union[ModelDType, torch.dtype] = "auto", dtype: Union[ModelDType, torch.dtype] = "auto",
num_gpu_blocks: int = 1000,
block_size: int = 16, block_size: int = 16,
max_num_seqs: int = 256, max_num_seqs: int = 256,
max_num_batched_tokens: int = 8192, max_num_batched_tokens: int = 8192,
enable_chunked_prefill: bool = True,
add_mock_model_methods: bool = True) -> VllmConfig: add_mock_model_methods: bool = True) -> VllmConfig:
"""Create a VllmConfig for testing with reasonable defaults.""" """Create a VllmConfig for testing with reasonable defaults."""
...@@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ...@@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
) )
# Set cache blocks for testing # Set cache blocks for testing
# (these may be set during initialization normally) # (these may be set during initialization normally)
cache_config.num_gpu_blocks = 1000 cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0 cache_config.num_cpu_blocks = 0
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
...@@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ...@@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
) )
device_config = DeviceConfig() device_config = DeviceConfig()
......
...@@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation ...@@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way: Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache. * Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a * For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention. multi-head attention, while the compute is similar to multi-query attention.
...@@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention( ...@@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1), torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v v
) )
return spda_o @ W_O return spda_o @ W_O
NOTE: in the actual code, NOTE: in the actual code,
...@@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O ...@@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill ## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are For chunked prefill we want to use the compute friendly algorithm. We are
assuming a sufficiently large Sq / Skv ratio; in the future we may want to switch to assuming a sufficiently large Sq / Skv ratio; in the future we may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small. the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
fixed workspace size. fixed workspace size.
The chunked prefill approach is as follows: The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically, MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage used to bound the memory usage
q_c = h_t @ W_DQ q_c = h_t @ W_DQ
...@@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention( ...@@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
new_v, new_v,
causal=True, causal=True,
return_softmax_lse=True return_softmax_lse=True
) )
// Compute attention with the already existing context // Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)): for chunk_idx in range(cdiv(C, MCC)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment