Commit fc7db442 authored by zhuwenwen
Browse files

update fa interface tests

parent aa389394
from ...utils import VLLM_PATH, RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
import vllm.envs as envs
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists() assert chatml_jinja_path.exists()
...@@ -6,17 +7,31 @@ assert chatml_jinja_path.exists() ...@@ -6,17 +7,31 @@ assert chatml_jinja_path.exists()
def run_and_test_dummy_opt_api_server(model, tp=1): def run_and_test_dummy_opt_api_server(model, tp=1):
# the model is registered through the plugin # the model is registered through the plugin
server_args = [ if envs.VLLM_USE_TRITON_FLASH_ATTN:
"--gpu-memory-utilization", server_args = [
"0.10", "--gpu-memory-utilization",
"--dtype", "0.10",
"float16",# "float32", "--dtype",
"--chat-template", "float32",
str(chatml_jinja_path), "--chat-template",
"--load-format", str(chatml_jinja_path),
"dummy", "--load-format",
"-tp", "dummy",
f"{tp}", "-tp",
f"{tp}",
]
else:
server_args = [
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float16",
"--chat-template",
str(chatml_jinja_path),
"--load-format",
"dummy",
"-tp",
f"{tp}",
] ]
with RemoteOpenAIServer(model, server_args) as server: with RemoteOpenAIServer(model, server_args) as server:
client = server.get_client() client = server.get_client()
...@@ -33,10 +48,11 @@ def run_and_test_dummy_opt_api_server(model, tp=1): ...@@ -33,10 +48,11 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
) )
generated_text = completion.choices[0].message.content generated_text = completion.choices[0].message.content
assert generated_text is not None assert generated_text is not None
# make sure only the first token is generatedvim # make sure only the first token is generated
rest = generated_text.replace("<s>", "") rest = generated_text.replace("<s>", "")
assert rest == "" assert rest == ""
def test_oot_registration_for_api_server(dummy_opt_path: str): def test_oot_registration_for_api_server(dummy_opt_path: str):
dummy_opt_path="facebook/opt-125m"
run_and_test_dummy_opt_api_server(dummy_opt_path) run_and_test_dummy_opt_api_server(dummy_opt_path)
...@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple ...@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple
import pytest import pytest
import torch import torch
import vllm.attention.backends.flash_attn # noqa: F401 from vllm.utils import is_hip
if is_hip():
import flash_attn
else:
import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.utils import seed_everything from vllm.utils import seed_everything
...@@ -70,90 +74,90 @@ def ref_paged_attn( ...@@ -70,90 +74,90 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
if not is_hip():
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
@pytest.mark.parametrize("num_heads", NUM_HEADS) key_cache = torch.randn(num_blocks,
@pytest.mark.parametrize("head_size", HEAD_SIZES) block_size,
@pytest.mark.parametrize("block_size", BLOCK_SIZES) num_kv_heads,
@pytest.mark.parametrize("dtype", DTYPES) head_size,
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) dtype=dtype)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) value_cache = torch.randn_like(key_cache)
@torch.inference_mode() kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
) -> None:
torch.set_default_device("cuda")
seed_everything(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
key_cache = torch.randn(num_blocks, block_tables = torch.randint(0,
block_size, num_blocks,
num_kv_heads, (num_seqs, max_num_blocks_per_seq),
head_size, dtype=torch.int32)
dtype=dtype)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size output = torch.ops.vllm.flash_attn_with_kvcache(
block_tables = torch.randint(0, decode_query=query.unsqueeze(1),
num_blocks, key_cache=key_cache,
(num_seqs, max_num_blocks_per_seq), value_cache=value_cache,
dtype=torch.int32) softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)
output = torch.ops.vllm.flash_attn_with_kvcache( if num_blocks <= 2048:
decode_query=query.unsqueeze(1), test_utils = ["test_faketensor", "test_schema"]
key_cache=key_cache, else:
value_cache=value_cache, test_utils = ["test_faketensor"]
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)
if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]
opcheck(torch.ops.vllm.flash_attn_with_kvcache, opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(), args=tuple(),
kwargs=dict( kwargs=dict(
decode_query=query.unsqueeze(1), decode_query=query.unsqueeze(1),
key_cache=key_cache, key_cache=key_cache,
value_cache=value_cache, value_cache=value_cache,
softmax_scale=scale, softmax_scale=scale,
causal=True, causal=True,
block_table=block_tables, block_table=block_tables,
cache_seqlens=kv_lens_tensor, cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
), ),
test_utils=test_utils) test_utils=test_utils)
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
query=query, query=query,
key_cache=key_cache, key_cache=key_cache,
value_cache=value_cache, value_cache=value_cache,
query_lens=[1] * num_seqs, query_lens=[1] * num_seqs,
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap, soft_cap=soft_cap,
) )
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
...@@ -212,20 +216,35 @@ def test_varlen_with_paged_kv( ...@@ -212,20 +216,35 @@ def test_varlen_with_paged_kv(
num_blocks, num_blocks,
(num_seqs, max_num_blocks_per_seq), (num_seqs, max_num_blocks_per_seq),
dtype=torch.int32) dtype=torch.int32)
if is_hip():
output = torch.ops.vllm.flash_attn_varlen_func( output = flash_attn.flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
cu_seqlens_q=cu_query_lens, cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens, cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len, max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len, max_seqlen_k=max_kv_len,
softmax_scale=scale, softmax_scale=scale,
causal=True, causal=True,
window_size=window_size, window_size=window_size,
block_table=block_tables, block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0, softcap=soft_cap if soft_cap is not None else 0,
)
else:
output = torch.ops.vllm.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
) )
if num_blocks <= 2048: if num_blocks <= 2048:
...@@ -233,23 +252,24 @@ def test_varlen_with_paged_kv( ...@@ -233,23 +252,24 @@ def test_varlen_with_paged_kv(
else: else:
test_utils = ["test_faketensor"] test_utils = ["test_faketensor"]
opcheck(torch.ops.vllm.flash_attn_varlen_func, if not is_hip():
args=tuple(), opcheck(torch.ops.vllm.flash_attn_varlen_func,
kwargs=dict( args=tuple(),
q=query, kwargs=dict(
k=key_cache, q=query,
v=value_cache, k=key_cache,
cu_seqlens_q=cu_query_lens, v=value_cache,
cu_seqlens_k=cu_kv_lens, cu_seqlens_q=cu_query_lens,
max_seqlen_q=max_query_len, cu_seqlens_k=cu_kv_lens,
max_seqlen_k=max_kv_len, max_seqlen_q=max_query_len,
softmax_scale=scale, max_seqlen_k=max_kv_len,
causal=True, softmax_scale=scale,
window_size=window_size, causal=True,
block_table=block_tables, window_size=window_size,
softcap=soft_cap if soft_cap is not None else 0, block_table=block_tables,
), softcap=soft_cap if soft_cap is not None else 0,
test_utils=test_utils) ),
test_utils=test_utils)
ref_output = ref_paged_attn( ref_output = ref_paged_attn(
query=query, query=query,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment