Commit fc7db442 authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface tests

parent aa389394
from ...utils import VLLM_PATH, RemoteOpenAIServer from ...utils import VLLM_PATH, RemoteOpenAIServer
import vllm.envs as envs
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists() assert chatml_jinja_path.exists()
...@@ -6,11 +7,25 @@ assert chatml_jinja_path.exists() ...@@ -6,11 +7,25 @@ assert chatml_jinja_path.exists()
def run_and_test_dummy_opt_api_server(model, tp=1): def run_and_test_dummy_opt_api_server(model, tp=1):
# the model is registered through the plugin # the model is registered through the plugin
if envs.VLLM_USE_TRITON_FLASH_ATTN:
server_args = [ server_args = [
"--gpu-memory-utilization", "--gpu-memory-utilization",
"0.10", "0.10",
"--dtype", "--dtype",
"float16",# "float32", "float32",
"--chat-template",
str(chatml_jinja_path),
"--load-format",
"dummy",
"-tp",
f"{tp}",
]
else:
server_args = [
"--gpu-memory-utilization",
"0.10",
"--dtype",
"float16",
"--chat-template", "--chat-template",
str(chatml_jinja_path), str(chatml_jinja_path),
"--load-format", "--load-format",
...@@ -33,10 +48,11 @@ def run_and_test_dummy_opt_api_server(model, tp=1): ...@@ -33,10 +48,11 @@ def run_and_test_dummy_opt_api_server(model, tp=1):
) )
generated_text = completion.choices[0].message.content generated_text = completion.choices[0].message.content
assert generated_text is not None assert generated_text is not None
# make sure only the first token is generatedvim # make sure only the first token is generated
rest = generated_text.replace("<s>", "") rest = generated_text.replace("<s>", "")
assert rest == "" assert rest == ""
def test_oot_registration_for_api_server(dummy_opt_path: str): def test_oot_registration_for_api_server(dummy_opt_path: str):
dummy_opt_path="facebook/opt-125m"
run_and_test_dummy_opt_api_server(dummy_opt_path) run_and_test_dummy_opt_api_server(dummy_opt_path)
...@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple ...@@ -3,7 +3,11 @@ from typing import List, Optional, Tuple
import pytest import pytest
import torch import torch
import vllm.attention.backends.flash_attn # noqa: F401 from vllm.utils import is_hip
if is_hip():
import flash_attn
else:
import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.utils import seed_everything from vllm.utils import seed_everything
...@@ -70,16 +74,16 @@ def ref_paged_attn( ...@@ -70,16 +74,16 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0) return torch.cat(outputs, dim=0)
if not is_hip():
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@torch.inference_mode() @torch.inference_mode()
def test_flash_attn_with_paged_kv( def test_flash_attn_with_paged_kv(
kv_lens: List[int], kv_lens: List[int],
num_heads: Tuple[int, int], num_heads: Tuple[int, int],
head_size: int, head_size: int,
...@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv( ...@@ -87,7 +91,7 @@ def test_flash_attn_with_paged_kv(
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
num_blocks: int, num_blocks: int,
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
seed_everything(0) seed_everything(0)
num_seqs = len(kv_lens) num_seqs = len(kv_lens)
...@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv( ...@@ -212,7 +216,22 @@ def test_varlen_with_paged_kv(
num_blocks, num_blocks,
(num_seqs, max_num_blocks_per_seq), (num_seqs, max_num_blocks_per_seq),
dtype=torch.int32) dtype=torch.int32)
if is_hip():
output = flash_attn.flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
else:
output = torch.ops.vllm.flash_attn_varlen_func( output = torch.ops.vllm.flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
...@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv( ...@@ -233,6 +252,7 @@ def test_varlen_with_paged_kv(
else: else:
test_utils = ["test_faketensor"] test_utils = ["test_faketensor"]
if not is_hip():
opcheck(torch.ops.vllm.flash_attn_varlen_func, opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(), args=tuple(),
kwargs=dict( kwargs=dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment