Unverified Commit 10cc12ba authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

Feature/mla tests (#23195)


Signed-off-by: default avatarMatthew Bonanni <mbonanni001@gmail.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent a4fbb32f
......@@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache(
# Permute the context blocks (excluding block 0 which is null)
if randomize_blocks:
perm = torch.randperm(
blocks_end - 1) + 1 # Random permutation starting from block 1
# Random permutation starting from block 1
perm = torch.randperm(blocks_end - 1) + 1
else:
perm = torch.arange(
1, blocks_end) # Sequential order starting from block 1
# Sequential order starting from block 1
perm = torch.arange(1, blocks_end)
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
inv_perm[1:] = torch.argsort(
perm) + 1 # Add 1 to account for starting from block 1
# Add 1 to account for starting from block 1
inv_perm[1:] = torch.argsort(perm) + 1
kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
# Construct the right block table
......@@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium"
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_backend_correctness(batch_spec_name: str, model: str):
......@@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
"""
batch_spec = BATCH_SPECS[batch_spec_name]
vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens))
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=8192)
device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
......@@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol=rtol,
atol=atol)
if not all_close:
print(f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
print(f"[{backend_name}] output: {backend_output}")
print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
\ No newline at end of file
This diff is collapsed.
......@@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS_VLLM_V1:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
_Backend.CUTLASS_MLA:
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}
if backend_name not in backend_map:
......@@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
tensor_parallel_size: int = 1,
max_model_len: int = 1024,
dtype: Union[ModelDType, torch.dtype] = "auto",
num_gpu_blocks: int = 1000,
block_size: int = 16,
max_num_seqs: int = 256,
max_num_batched_tokens: int = 8192,
enable_chunked_prefill: bool = True,
add_mock_model_methods: bool = True) -> VllmConfig:
"""Create a VllmConfig for testing with reasonable defaults."""
......@@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
)
# Set cache blocks for testing
# (these may be set during initialization normally)
cache_config.num_gpu_blocks = 1000
cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0
parallel_config = ParallelConfig(
......@@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
)
device_config = DeviceConfig()
......
......@@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
......@@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
)
return spda_o @ W_O
NOTE: in the actual code,
......@@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
......@@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
new_v,
casual=True,
return_softmax_lse=True
)
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment