Commit ec417173 authored by zhuwenwen's avatar zhuwenwen
Browse files

update test_prefix_prefill.py

parent cc4b902f
...@@ -20,7 +20,7 @@ CUDA_DEVICES = [ ...@@ -20,7 +20,7 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] ]
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"] if not is_hip() else ["auto"]
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
...@@ -460,56 +460,8 @@ def test_contexted_kv_attention_alibi( ...@@ -460,56 +460,8 @@ def test_contexted_kv_attention_alibi(
...]) ...])
seq_start += seq_len seq_start += seq_len
query_start += query_len query_start += query_len
query = query_pad torch.cuda.synchronize()
end_time = time.time()
if num_kv_heads != num_heads: print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
# project the key and value tensors to the desired number of torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
# heads. \ No newline at end of file
#
# see also: vllm/model_executor/layers/attention.py
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
query.shape[-1])
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], num_kv_heads,
num_queries_per_kv, value.shape[-1])
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output_ref = torch.empty_like(output)
seq_start = 0
query_start = 0
start_time = time.time()
# Attention with alibi slopes.
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
seq_end = seq_start + seq_len
query_end = query_start + query_len
out = xops.memory_efficient_attention_forward(query[:,
seq_start:seq_end],
key[:,
seq_start:seq_end],
value[:,
seq_start:seq_end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
out = out.view_as(query[:, seq_start:seq_end]).view(
seq_len, num_heads, head_size)
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
...])
seq_start += seq_len
query_start += query_len
torch.cuda.synchronize()
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment