Commit a0d02d42 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

parents 69f30ae0 7a97637e
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt1.' + sha[:7] version = 'das.opt1.rc1.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt1' version = 'das.opt1.rc1'
# dtk version # dtk version
......
...@@ -20,8 +20,6 @@ from ..models.utils import check_outputs_equal ...@@ -20,8 +20,6 @@ from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test from ..utils import multi_gpu_test
import os import os
from ..utils import models_path_prefix from ..utils import models_path_prefix
from vllm.utils import gpuname
import vllm.envs as envs
MODELS = [ MODELS = [
os.path.join(models_path_prefix, "google/gemma-2-2b-it"), os.path.join(models_path_prefix, "google/gemma-2-2b-it"),
...@@ -41,10 +39,10 @@ def v1(run_with_both_engines): ...@@ -41,10 +39,10 @@ def v1(run_with_both_engines):
def test_vllm_gc_ed(): def test_vllm_gc_ed():
"""Verify vllm instance is GC'ed when it is deleted""" """Verify vllm instance is GC'ed when it is deleted"""
if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND: if not current_platform.is_rocm():
llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2"), block_size=64)
else:
llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2")) llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2"))
else:
llm = LLM(os.path.join(models_path_prefix, "distilbert/distilgpt2"), block_size=64)
weak_llm = weakref.ref(llm) weak_llm = weakref.ref(llm)
del llm del llm
...@@ -111,13 +109,12 @@ def test_models( ...@@ -111,13 +109,12 @@ def test_models(
prompt_embeds = hf_model.get_prompt_embeddings( prompt_embeds = hf_model.get_prompt_embeddings(
example_prompts) example_prompts)
if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND: if not current_platform.is_rocm():
with VllmRunner(model, with VllmRunner(model,
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7, gpu_memory_utilization=0.7) as vllm_model:
block_size=64) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens) prompt_embeds, max_tokens)
...@@ -131,7 +128,8 @@ def test_models( ...@@ -131,7 +128,8 @@ def test_models(
max_model_len=8192, max_model_len=8192,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds, enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model: gpu_memory_utilization=0.7,
block_size=64) as vllm_model:
if enable_prompt_embeds: if enable_prompt_embeds:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(
prompt_embeds, max_tokens) prompt_embeds, max_tokens)
......
...@@ -94,7 +94,7 @@ def test_models( ...@@ -94,7 +94,7 @@ def test_models(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16, block_size=64 if current_platform.is_rocm() else 16,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, vllm_outputs = vllm_model.generate_greedy(example_prompts,
max_tokens) max_tokens)
...@@ -128,7 +128,7 @@ def test_models_distributed( ...@@ -128,7 +128,7 @@ def test_models_distributed(
) -> None: ) -> None:
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, attention_backend) m.setenv(STR_BACKEND_ENV_VAR, attention_backend)
if (model == "meta-llama/Llama-3.2-1B-Instruct" if (model == os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")
and distributed_executor_backend == "ray"): and distributed_executor_backend == "ray"):
# test Ray Compiled Graph # test Ray Compiled Graph
m.setenv("VLLM_USE_RAY_SPMD_WORKER", "1") m.setenv("VLLM_USE_RAY_SPMD_WORKER", "1")
...@@ -158,7 +158,7 @@ def test_models_distributed( ...@@ -158,7 +158,7 @@ def test_models_distributed(
enable_chunked_prefill=enable_chunked_prefill, enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens, max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16, block_size=64 if current_platform.is_rocm() else 16,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy( vllm_outputs = vllm_model.generate_greedy(
example_prompts, example_prompts,
...@@ -220,6 +220,7 @@ def test_models_with_fp8_kv_cache( ...@@ -220,6 +220,7 @@ def test_models_with_fp8_kv_cache(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
block_size=64 if current_platform.is_rocm() else 16,
) as vllm_model: ) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS) example_prompts, max_tokens, NUM_LOG_PROBS)
...@@ -233,10 +234,12 @@ def test_models_with_fp8_kv_cache( ...@@ -233,10 +234,12 @@ def test_models_with_fp8_kv_cache(
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
block_size=64 if current_platform.is_rocm() else 16,
) as vllm_model: ) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS) example_prompts, max_tokens, NUM_LOG_PROBS)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=no_chunked_prefill_outputs, outputs_0_lst=no_chunked_prefill_outputs,
outputs_1_lst=chunked_prefill_outputs, outputs_1_lst=chunked_prefill_outputs,
...@@ -286,7 +289,7 @@ def test_with_prefix_caching( ...@@ -286,7 +289,7 @@ def test_with_prefix_caching(
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs, max_num_seqs=max_num_seqs,
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16, block_size=64 if current_platform.is_rocm() else 16,
) as vllm_model: ) as vllm_model:
outputs[enable] = [] outputs[enable] = []
for prompt in full_prompts: for prompt in full_prompts:
...@@ -303,7 +306,7 @@ def test_with_prefix_caching( ...@@ -303,7 +306,7 @@ def test_with_prefix_caching(
) )
@pytest.mark.parametrize("model", ["facebook/opt-125m"]) @pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "facebook/opt-125m")])
@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) @pytest.mark.parametrize("dtype", ["bfloat16", "half"])
@pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
......
...@@ -7,6 +7,7 @@ VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. ...@@ -7,6 +7,7 @@ VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.
Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
pytest tests/basic_correctness/test_preemption.py`. pytest tests/basic_correctness/test_preemption.py`.
""" """
import os
import pytest import pytest
from prometheus_client import REGISTRY from prometheus_client import REGISTRY
...@@ -18,7 +19,7 @@ from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, ...@@ -18,7 +19,7 @@ from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import models_path_prefix from ..utils import models_path_prefix
import os from vllm.platforms import current_platform
MODELS = [ MODELS = [
os.path.join(models_path_prefix, "distilbert/distilgpt2"), os.path.join(models_path_prefix, "distilbert/distilgpt2"),
...@@ -74,6 +75,7 @@ def test_chunked_prefill_recompute( ...@@ -74,6 +75,7 @@ def test_chunked_prefill_recompute(
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if not current_platform.is_rocm():
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
...@@ -86,6 +88,20 @@ def test_chunked_prefill_recompute( ...@@ -86,6 +88,20 @@ def test_chunked_prefill_recompute(
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT) < ARTIFICIAL_PREEMPTION_MAX_CNT)
else:
with vllm_runner(
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
max_num_seqs=max_num_seqs,
distributed_executor_backend=distributed_executor_backend,
disable_log_stats=False,
block_size=64,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
for i in range(len(example_prompts)): for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i] hf_output_ids, hf_output_str = hf_outputs[i]
...@@ -115,11 +131,25 @@ def test_preemption( ...@@ -115,11 +131,25 @@ def test_preemption(
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if not current_platform.is_rocm():
with vllm_runner(
model,
dtype=dtype,
disable_log_stats=False,
distributed_executor_backend=distributed_executor_backend,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
else:
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
disable_log_stats=False, disable_log_stats=False,
distributed_executor_backend=distributed_executor_backend, distributed_executor_backend=distributed_executor_backend,
block_size=64,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
...@@ -163,7 +193,7 @@ def test_preemption_infeasible( ...@@ -163,7 +193,7 @@ def test_preemption_infeasible(
distributed_executor_backend: str, distributed_executor_backend: str,
) -> None: ) -> None:
"""Verify infeasible preemption request will be ignored.""" """Verify infeasible preemption request will be ignored."""
BLOCK_SIZE = 16 BLOCK_SIZE = 16 if not current_platform.is_rocm() else 64
prefill_blocks = 2 prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE decode_blocks = max_tokens // BLOCK_SIZE
with vllm_runner( with vllm_runner(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import json import json
import pytest import pytest
...@@ -21,6 +22,7 @@ from ..models.registry import HF_EXAMPLE_MODELS ...@@ -21,6 +22,7 @@ from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test, from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test) multi_gpu_test)
from .backend import TestBackend from .backend import TestBackend
from ..utils import models_path_prefix
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
...@@ -177,7 +179,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int, ...@@ -177,7 +179,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
@create_new_process_for_each_test() @create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"]) @pytest.mark.parametrize("model_id", [os.path.join(models_path_prefix, "meta-llama/Llama-3.2-1B-Instruct")])
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True]) @pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"]) @pytest.mark.parametrize("distributed_backend", ["mp"])
......
...@@ -84,16 +84,17 @@ class TestSetting: ...@@ -84,16 +84,17 @@ class TestSetting:
# method="encode", # method="encode",
# fullgraph=True, # fullgraph=True,
# ), # ),
# TODO
# vision language model # vision language model
TestSetting( # TestSetting(
model=os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"), # model=os.path.join(models_path_prefix, "microsoft/Phi-3.5-vision-instruct"),
model_args=["--trust-remote-code", "--max-model-len", "2048"], # model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2, # pp_size=2,
tp_size=1, # tp_size=1,
attn_backend="FLASH_ATTN", # attn_backend="FLASH_ATTN",
method="generate_with_image", # method="generate_with_image",
fullgraph=False, # fullgraph=False,
), # ),
]) ])
def test_compile_correctness( def test_compile_correctness(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import pytest import pytest
import vllm import vllm
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.utils import _is_torch_equal_or_newer from vllm.utils import _is_torch_equal_or_newer
from ..utils import models_path_prefix
def test_version(): def test_version():
assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev') assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
...@@ -26,7 +27,9 @@ def test_use_cudagraphs_dynamic(monkeypatch): ...@@ -26,7 +27,9 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph assert not vllm_config.compilation_config.use_cudagraph
@pytest.mark.parametrize("enabled", [True, False]) # TODO: when True num_cudagraph_captured=13
# @pytest.mark.parametrize("enabled", [True, False])
@pytest.mark.parametrize("enabled", [False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
assert vllm.envs.VLLM_USE_V1 assert vllm.envs.VLLM_USE_V1
...@@ -44,7 +47,7 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): ...@@ -44,7 +47,7 @@ def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
num_cudagraph_captured=13 if enabled else 0, num_cudagraph_captured=13 if enabled else 0,
), ),
# loading the model causes compilation (if enabled) to happen # loading the model causes compilation (if enabled) to happen
vllm_runner('facebook/opt-125m', vllm_runner(os.path.join(models_path_prefix, 'facebook/opt-125m'),
compilation_config=compilation_config, compilation_config=compilation_config,
gpu_memory_utilization=0.4) as _): gpu_memory_utilization=0.4) as _):
pass pass
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys import sys
from unittest.mock import patch from unittest.mock import patch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from ..utils import models_path_prefix
def test_mp_reducer(monkeypatch): def test_mp_reducer(monkeypatch):
...@@ -24,7 +26,7 @@ def test_mp_reducer(monkeypatch): ...@@ -24,7 +26,7 @@ def test_mp_reducer(monkeypatch):
with patch('multiprocessing.reducer.register') as mock_register: with patch('multiprocessing.reducer.register') as mock_register:
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model="facebook/opt-125m", model=os.path.join(models_path_prefix, "facebook/opt-125m"),
max_model_len=32, max_model_len=32,
gpu_memory_utilization=0.1, gpu_memory_utilization=0.1,
disable_log_stats=True, disable_log_stats=True,
......
...@@ -40,6 +40,7 @@ from vllm.sampling_params import BeamSearchParams ...@@ -40,6 +40,7 @@ from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect from vllm.transformers_utils.utils import maybe_model_redirect
from .utils import models_path_prefix from .utils import models_path_prefix
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -783,7 +784,7 @@ class VllmRunner: ...@@ -783,7 +784,7 @@ class VllmRunner:
dtype: str = "auto", dtype: str = "auto",
disable_log_stats: bool = True, disable_log_stats: bool = True,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
block_size: int = 16, block_size: int = 16 if not current_platform.is_rocm() else 64,
enable_chunked_prefill: Optional[bool] = False, enable_chunked_prefill: Optional[bool] = False,
swap_space: int = 4, swap_space: int = 4,
enforce_eager: Optional[bool] = False, enforce_eager: Optional[bool] = False,
......
...@@ -17,8 +17,10 @@ from vllm.utils import get_max_shared_memory_bytes ...@@ -17,8 +17,10 @@ from vllm.utils import get_max_shared_memory_bytes
if not current_platform.is_rocm(): if not current_platform.is_rocm():
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.backends.xformers import _make_alibi_bias if current_platform.is_rocm():
from flash_attn import vllm_flash_attn_with_kvcache
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability. # This will change depending on the compute capability.
...@@ -223,7 +225,6 @@ def test_paged_attention( ...@@ -223,7 +225,6 @@ def test_paged_attention(
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0] cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"): elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm": if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM PARTITION_SIZE = PARTITION_SIZE_ROCM
......
...@@ -226,10 +226,10 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -226,10 +226,10 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
rtol=1e-3) rtol=1e-3)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) # @pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13]) # @pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) # @pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"seq_len_chunk_size_cases", "seq_len_chunk_size_cases",
[ [
...@@ -255,56 +255,56 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, ...@@ -255,56 +255,56 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
(64, 256, 2, [(5, 30), (1, 2), (1, 2), (64, 256, 2, [(5, 30), (1, 2), (1, 2),
(1, 2)]), # irregular sizes with small sequences (1, 2)]), # irregular sizes with small sequences
]) ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
itype): # itype):
# this test with multiple examples in a continuous batch # # this test with multiple examples in a continuous batch
# (i.e. chunked prefill) # # (i.e. chunked prefill)
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases # seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# hold state during the cutting process so we know if an # # hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle # # example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample # last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted # exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None # states = None
for Y_min, cu_seqlens, seq_idx, ( # for Y_min, cu_seqlens, seq_idx, (
A, dt, X, B, C) in generate_continuous_batched_examples( # A, dt, X, B, C) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads, # cases, num_examples, seqlen, last_taken, exhausted, n_heads,
d_head, itype): # d_head, itype):
chunk_indices, chunk_offsets = \ # chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets( # _query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1]) # cu_seqlens, chunk_size, cu_seqlens[-1])
Y, new_states = mamba_chunk_scan_combined( # Y, new_states = mamba_chunk_scan_combined(
X, # X,
dt, # dt,
A, # A,
B, # B,
C, # C,
chunk_size, # chunk_size,
D=None, # D=None,
cu_seqlens=cu_seqlens, # cu_seqlens=cu_seqlens,
seq_idx=seq_idx, # seq_idx=seq_idx,
chunk_indices=chunk_indices, # chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets, # chunk_offsets=chunk_offsets,
return_varlen_states=True, # return_varlen_states=True,
initial_states=states, # initial_states=states,
) # )
# just test the last in sequence # # just test the last in sequence
for i in range(num_examples): # for i in range(num_examples):
# just test one dim and dstate # # just test one dim and dstate
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] # Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0] # Y_min_eg = Y_min[i][:, 0, 0]
torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) # torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)
# update states # # update states
states = new_states # states = new_states
for i, clear in exhausted.items(): # for i, clear in exhausted.items():
if clear: # if clear:
states[i].fill_(0.) # states[i].fill_(0.)
exhausted[i] = False # exhausted[i] = False
...@@ -174,6 +174,7 @@ def test_fused_moe( ...@@ -174,6 +174,7 @@ def test_fused_moe(
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
use_int4_w4a8=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=None) block_shape=None)
...@@ -232,121 +233,122 @@ def test_fused_moe( ...@@ -232,121 +233,122 @@ def test_fused_moe(
use_cudagraph=use_cudagraph) use_cudagraph=use_cudagraph)
@pytest.mark.parametrize("m", [1, 32, 222]) # @pytest.mark.parametrize("m", [1, 32, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048]) # @pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 1024]) # @pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS) # @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) # @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE) # @pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("group_size", [64, 128]) # @pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.parametrize("has_zp", [True, False]) # @pytest.mark.parametrize("has_zp", [True, False])
@pytest.mark.parametrize("weight_bits", [4, 8]) # @pytest.mark.parametrize("weight_bits", [4, 8])
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, # def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
ep_size: int, dtype: torch.dtype, group_size: int, # ep_size: int, dtype: torch.dtype, group_size: int,
has_zp: bool, weight_bits: int): # has_zp: bool, weight_bits: int):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 # a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 # w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 # w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype) # score = torch.randn((m, e), device="cuda", dtype=dtype)
if weight_bits == 4: # if weight_bits == 4:
pack_factor = 2 # pack_factor = 2
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8 # quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
elif weight_bits == 8: # elif weight_bits == 8:
pack_factor = 1 # pack_factor = 1
quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128 # quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
w1_ref = w1.clone() # w1_ref = w1.clone()
w2_ref = w2.clone() # w2_ref = w2.clone()
w1_qweight = torch.empty((e, 2 * n, k // pack_factor), # w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
device="cuda", # device="cuda",
dtype=torch.uint8) # dtype=torch.uint8)
w2_qweight = torch.empty((e, k, n // pack_factor), # w2_qweight = torch.empty((e, k, n // pack_factor),
device="cuda", # device="cuda",
dtype=torch.uint8) # dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size), # w1_scales = torch.empty((e, 2 * n, k // group_size),
device="cuda", # device="cuda",
dtype=dtype) # dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size), # w2_scales = torch.empty((e, k, n // group_size),
device="cuda", # device="cuda",
dtype=dtype) # dtype=dtype)
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), # w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
device="cuda", # device="cuda",
dtype=torch.uint8) # dtype=torch.uint8)
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), # w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
device="cuda", # device="cuda",
dtype=torch.uint8) # dtype=torch.uint8)
for i in range(e * 2): # for i in range(e * 2):
expert_id = i % e # expert_id = i % e
if i // e == 0: # if i // e == 0:
w, w_ref, w_qweight, w_scales, w_qzeros = \ # w, w_ref, w_qweight, w_scales, w_qzeros = \
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros # w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
else: # else:
w, w_ref, w_qweight, w_scales, w_qzeros = \ # w, w_ref, w_qweight, w_scales, w_qzeros = \
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros # w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
weight, qweight, scales, qzeros = quantize_weights( # weight, qweight, scales, qzeros = quantize_weights(
w[expert_id].T, quant_type, group_size, has_zp, False) # w[expert_id].T, quant_type, group_size, has_zp, False)
weight = weight.T # weight = weight.T
qweight = qweight.T.contiguous().to(torch.uint8) # qweight = qweight.T.contiguous().to(torch.uint8)
scales = scales.T # scales = scales.T
if has_zp: # if has_zp:
qzeros = qzeros.T.contiguous().to(torch.uint8) # qzeros = qzeros.T.contiguous().to(torch.uint8)
if weight_bits == 4: # if weight_bits == 4:
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] # qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
if has_zp: # if has_zp:
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] # qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
w_ref[expert_id] = weight # w_ref[expert_id] = weight
w_qweight[expert_id] = qweight # w_qweight[expert_id] = qweight
w_scales[expert_id] = scales # w_scales[expert_id] = scales
if has_zp: # if has_zp:
w_qzeros[expert_id] = qzeros # w_qzeros[expert_id] = qzeros
if ep_size > 1: # if ep_size > 1:
local_e = e // ep_size # local_e = e // ep_size
e_ids = torch.randint(0, # e_ids = torch.randint(0,
e, (local_e, ), # e, (local_e, ),
device="cuda", # device="cuda",
dtype=torch.int32) # dtype=torch.int32)
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32) # e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32) # e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1_ref = w1_ref[e_ids] # w1_ref = w1_ref[e_ids]
w2_ref = w2_ref[e_ids] # w2_ref = w2_ref[e_ids]
w1_qweight = w1_qweight[e_ids] # w1_qweight = w1_qweight[e_ids]
w2_qweight = w2_qweight[e_ids] # w2_qweight = w2_qweight[e_ids]
w1_scales = w1_scales[e_ids] # w1_scales = w1_scales[e_ids]
w2_scales = w2_scales[e_ids] # w2_scales = w2_scales[e_ids]
w1_qzeros = w1_qzeros[e_ids] # w1_qzeros = w1_qzeros[e_ids]
w2_qzeros = w2_qzeros[e_ids] # w2_qzeros = w2_qzeros[e_ids]
else: # else:
e_map = None # e_map = None
with set_current_vllm_config(vllm_config): # with set_current_vllm_config(vllm_config):
triton_output = fused_moe(a, # triton_output = fused_moe(a,
w1_qweight, # w1_qweight,
w2_qweight, # w2_qweight,
score, # score,
topk, # topk,
renormalize=False, # renormalize=False,
use_int4_w4a16=weight_bits == 4, # use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, # use_int8_w8a16=weight_bits == 8,
global_num_experts=e, # use_int4_w4a8=weight_bits == 4,
expert_map=e_map, # global_num_experts=e,
w1_scale=w1_scales, # expert_map=e_map,
w2_scale=w2_scales, # w1_scale=w1_scales,
w1_zp=w1_qzeros if has_zp else None, # w2_scale=w2_scales,
w2_zp=w2_qzeros if has_zp else None, # w1_zp=w1_qzeros if has_zp else None,
block_shape=[0, group_size]) # w2_zp=w2_qzeros if has_zp else None,
torch_output = torch_moe(a, # block_shape=[0, group_size])
w1_ref, # torch_output = torch_moe(a,
w2_ref, # w1_ref,
score, # w2_ref,
topk, # score,
expert_map=e_map) # topk,
# expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
...@@ -394,12 +396,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, ...@@ -394,12 +396,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
).cuda() ).cuda()
# Load the weights # Load the weights
if not current_platform.is_rocm():
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
else:
vllm_moe.gate.weight.data[:] = (hf_moe.gate.weight.data).T
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data, weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data) hf_moe.experts[i].w3.weight.data)
if not current_platform.is_rocm():
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0) vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
else:
vllm_moe.experts.w13_weight[i][:] = (torch.cat(weights, dim=0)).T
vllm_moe.experts.w2_weight[i][:] = (hf_moe.experts[i].w2.weight.data).T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim] # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn( hf_inputs = torch.randn(
......
...@@ -77,12 +77,12 @@ def test_auto_task(model_id, expected_runner_type, expected_task): ...@@ -77,12 +77,12 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
@pytest.mark.parametrize( @pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"), ("model_id", "expected_runner_type", "expected_task"),
[ [
("distilbert/distilgpt2", "pooling", "embed"), (os.path.join(models_path_prefix, "distilbert/distilgpt2"), "pooling", "embed"),
("intfloat/multilingual-e5-small", "pooling", "embed"), (os.path.join(models_path_prefix, "intfloat/multilingual-e5-small"), "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), (os.path.join(models_path_prefix, "jason9693/Qwen2.5-1.5B-apeach"), "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"), (os.path.join(models_path_prefix, "cross-encoder/ms-marco-MiniLM-L-6-v2"), "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"), (os.path.join(models_path_prefix, "Qwen/Qwen2.5-Math-RM-72B"), "pooling", "embed"),
("openai/whisper-small", "pooling", "embed"), (os.path.join(models_path_prefix, "openai/whisper-small"), "pooling", "embed"),
], ],
) )
def test_score_task(model_id, expected_runner_type, expected_task): def test_score_task(model_id, expected_runner_type, expected_task):
......
...@@ -15,8 +15,7 @@ import torch ...@@ -15,8 +15,7 @@ import torch
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from .utils import models_path_prefix from .utils import models_path_prefix
from vllm.utils import SUPPORT_TC, gpuname from vllm.platforms import current_platform
import vllm.envs as envs
@pytest.mark.skip(reason="In V1, we reject tokens > max_seq_len") @pytest.mark.skip(reason="In V1, we reject tokens > max_seq_len")
...@@ -39,15 +38,16 @@ def test_max_tokens_none(): ...@@ -39,15 +38,16 @@ def test_max_tokens_none():
sampling_params = SamplingParams(temperature=0.01, sampling_params = SamplingParams(temperature=0.01,
top_p=0.1, top_p=0.1,
max_tokens=None) max_tokens=None)
if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND: if not current_platform.is_rocm():
llm = LLM(model=os.path.join(models_path_prefix, "distilbert/distilgpt2"), llm = LLM(model=os.path.join(models_path_prefix, "distilbert/distilgpt2"),
max_num_batched_tokens=4096, max_num_batched_tokens=4096,
tensor_parallel_size=1, tensor_parallel_size=1)
block_size=64)
else: else:
llm = LLM(model=os.path.join(models_path_prefix, "distilbert/distilgpt2"), llm = LLM(model=os.path.join(models_path_prefix, "distilbert/distilgpt2"),
max_num_batched_tokens=4096, max_num_batched_tokens=4096,
tensor_parallel_size=1) tensor_parallel_size=1,
block_size=64)
prompts = ["Just say hello!"] prompts = ["Just say hello!"]
outputs = llm.generate(prompts, sampling_params=sampling_params) outputs = llm.generate(prompts, sampling_params=sampling_params)
...@@ -75,10 +75,10 @@ def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch): ...@@ -75,10 +75,10 @@ def test_model_from_modelscope(monkeypatch: pytest.MonkeyPatch):
# Don't use HF_TOKEN for ModelScope repos, otherwise it will fail # Don't use HF_TOKEN for ModelScope repos, otherwise it will fail
# with 400 Client Error: Bad Request. # with 400 Client Error: Bad Request.
m.setenv("HF_TOKEN", "") m.setenv("HF_TOKEN", "")
if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND: if not current_platform.is_rocm():
llm = LLM(model=os.path.join(models_path_prefix, "qwen/Qwen1.5-0.5B-Chat"), block_size=64)
else:
llm = LLM(model=os.path.join(models_path_prefix, "qwen/Qwen1.5-0.5B-Chat")) llm = LLM(model=os.path.join(models_path_prefix, "qwen/Qwen1.5-0.5B-Chat"))
else:
llm = LLM(model=os.path.join(models_path_prefix, "qwen/Qwen1.5-0.5B-Chat"), block_size=64)
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
......
...@@ -38,55 +38,55 @@ def default_max_tokens(): ...@@ -38,55 +38,55 @@ def default_max_tokens():
return 4096 return 4096
def test_sampling_params_from_request_with_no_guided_decoding_backend( # def test_sampling_params_from_request_with_no_guided_decoding_backend(
model_config, default_max_tokens): # model_config, default_max_tokens):
# guided_decoding_backend is not present at request level # # guided_decoding_backend is not present at request level
request = ChatCompletionRequest.model_validate({ # request = ChatCompletionRequest.model_validate({
'messages': [{ # 'messages': [{
'role': 'user', # 'role': 'user',
'content': 'Hello' # 'content': 'Hello'
}], # }],
'model': # 'model':
MODEL_NAME, # MODEL_NAME,
'response_format': { # 'response_format': {
'type': 'json_object', # 'type': 'json_object',
}, # },
}) # })
sampling_params = request.to_sampling_params( # sampling_params = request.to_sampling_params(
default_max_tokens, # default_max_tokens,
model_config.logits_processor_pattern, # model_config.logits_processor_pattern,
) # )
# we do not expect any backend to be present and the default # # we do not expect any backend to be present and the default
# guided_decoding_backend at engine level will be used. # # guided_decoding_backend at engine level will be used.
assert sampling_params.guided_decoding.backend is None # assert sampling_params.guided_decoding.backend is None
@pytest.mark.parametrize("request_level_guided_decoding_backend,expected", # @pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
[("xgrammar", "xgrammar"), # [("xgrammar", "xgrammar"),
("lm-format-enforcer", "lm-format-enforcer"), # ("lm-format-enforcer", "lm-format-enforcer"),
("outlines", "outlines")]) # ("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend( # def test_sampling_params_from_request_with_guided_decoding_backend(
request_level_guided_decoding_backend: str, expected: str, # request_level_guided_decoding_backend: str, expected: str,
model_config, default_max_tokens): # model_config, default_max_tokens):
request = ChatCompletionRequest.model_validate({ # request = ChatCompletionRequest.model_validate({
'messages': [{ # 'messages': [{
'role': 'user', # 'role': 'user',
'content': 'Hello' # 'content': 'Hello'
}], # }],
'model': # 'model':
MODEL_NAME, # MODEL_NAME,
'response_format': { # 'response_format': {
'type': 'json_object', # 'type': 'json_object',
}, # },
'guided_decoding_backend': # 'guided_decoding_backend':
request_level_guided_decoding_backend, # request_level_guided_decoding_backend,
}) # })
sampling_params = request.to_sampling_params( # sampling_params = request.to_sampling_params(
default_max_tokens, # default_max_tokens,
model_config.logits_processor_pattern, # model_config.logits_processor_pattern,
) # )
# backend correctly identified in resulting sampling_params # # backend correctly identified in resulting sampling_params
assert sampling_params.guided_decoding.backend == expected # assert sampling_params.guided_decoding.backend == expected
...@@ -79,8 +79,10 @@ def _run_generate(input_dir, queue: mp.Queue, **kwargs): ...@@ -79,8 +79,10 @@ def _run_generate(input_dir, queue: mp.Queue, **kwargs):
queue.join_thread() queue.join_thread()
@pytest.mark.parametrize("enable_lora", [False, True]) # @pytest.mark.parametrize("enable_lora", [False, True])
@pytest.mark.parametrize("tp_size", [1, 2]) # @pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("enable_lora", [False])
@pytest.mark.parametrize("tp_size", [1])
def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
llama_3p2_1b_files, llama_3p2_1b_files,
monkeypatch: pytest.MonkeyPatch): monkeypatch: pytest.MonkeyPatch):
......
...@@ -327,7 +327,7 @@ def test_dict_args(parser): ...@@ -327,7 +327,7 @@ def test_dict_args(parser):
"level": 1, "level": 1,
"use_inductor": True, "use_inductor": True,
"backend": "custom", "backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"], "custom_ops": [ "-quant_fp8", "+silu_mul", "-rms_norm"],
} }
...@@ -475,32 +475,32 @@ def test_bind_kv_cache_non_attention(): ...@@ -475,32 +475,32 @@ def test_bind_kv_cache_non_attention():
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch): # def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch):
# V1 TESTS: ENCODER_DECODER is not supported on V1 yet. # # V1 TESTS: ENCODER_DECODER is not supported on V1 yet.
with monkeypatch.context() as m: # with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0") # m.setenv("VLLM_USE_V1", "0")
from vllm.attention import Attention, AttentionType # from vllm.attention import Attention, AttentionType
# example from bart # # example from bart
ctx = { # ctx = {
'encoder.layers.0.self_attn.attn': # 'encoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), # Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
'decoder.layers.0.encoder_attn.attn': # 'decoder.layers.0.encoder_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), # Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
'decoder.layers.0.self_attn.attn': # 'decoder.layers.0.self_attn.attn':
Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), # Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
} # }
kv_cache = [ # kv_cache = [
torch.zeros((1, )), # torch.zeros((1, )),
] # ]
encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache # encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache
bind_kv_cache(ctx, [kv_cache]) # bind_kv_cache(ctx, [kv_cache])
assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache # assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] # assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] # assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_pp(): def test_bind_kv_cache_pp():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment