Commit e661d594 authored by zhuwenwen
Browse files

Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1

parents 6b16ea2e 4db5176d
...@@ -2,6 +2,5 @@ ...@@ -2,6 +2,5 @@
-r requirements-common.txt -r requirements-common.txt
# Dependencies for x86_64 CPUs # Dependencies for x86_64 CPUs
torch == 2.3.1+cpu; platform_machine != "ppc64le" torch == 2.4.0+cpu; platform_machine != "ppc64le"
torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
# Dependencies for NVIDIA GPUs # Dependencies for NVIDIA GPUs
ray >= 2.9 ray >= 2.9
nvidia-ml-py # for pynvml package nvidia-ml-py # for pynvml package
torch == 2.3.1 torch == 2.4.0
# These must be updated alongside torch # These must be updated alongside torch
torchvision == 0.18.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27 # Requires PyTorch 2.3.1 xformers == 0.0.27.post2 # Requires PyTorch 2.4.0
vllm-flash-attn == 2.5.9.post1 # Requires PyTorch 2.3.1 vllm-flash-attn == 2.6.1 # Requires PyTorch 2.4.0
# Common dependencies # Common dependencies
-r requirements-common.txt # -r requirements-common.txt
# TODO: remove this temporary copy of the common dependencies once Optimum Intel supports Transformers >= 4.43.2
cmake >= 3.21
ninja # For faster builds.
psutil
sentencepiece # Required for LLaMA tokenizer.
numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers < 4.43
tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
aiohttp
openai
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
# OpenVINO dependencies # OpenVINO dependencies
torch >= 2.1.2 torch >= 2.1.2
openvino ~= 2024.3.0.dev openvino ~= 2024.3.0.dev
openvino-tokenizers[transformers] ~= 2024.3.0.0.dev
optimum-intel[openvino] >= 1.18.1 optimum-intel[openvino] >= 1.18.1
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
# Needed for Ray accelerated DAG tests
-r requirements-adag.txt
# testing # testing
pytest pytest
tensorizer>=2.9.0 tensorizer>=2.9.0
...@@ -14,8 +17,8 @@ peft ...@@ -14,8 +17,8 @@ peft
requests requests
ray ray
sentence-transformers # required for embedding sentence-transformers # required for embedding
sparseml==1.8.0 # required for compressed-tensors
compressed-tensors==0.4.0 # required for compressed-tensors compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test
# Benchmarking # Benchmarking
aiohttp aiohttp
......
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
# Dependencies for TPU # Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA. # Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu. # You can install the dependencies in Dockerfile.tpu.
triton # To avoid import errors ray
...@@ -188,9 +188,6 @@ class cmake_build_ext(build_ext): ...@@ -188,9 +188,6 @@ class cmake_build_ext(build_ext):
# match. # match.
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)] cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
if _install_punica():
cmake_args += ['-DVLLM_INSTALL_PUNICA_KERNELS=ON']
# #
# Setup parallelism and build tool # Setup parallelism and build tool
# #
...@@ -281,8 +278,8 @@ def _build_custom_ops() -> bool: ...@@ -281,8 +278,8 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu() return _is_cuda() or _is_hip() or _is_cpu()
def _install_punica() -> bool: def _build_core_ext() -> bool:
return envs.VLLM_INSTALL_PUNICA_KERNELS return not _is_neuron() and not _is_tpu()
def get_hipcc_rocm_version(): def get_hipcc_rocm_version():
...@@ -388,19 +385,20 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -388,19 +385,20 @@ def get_version_add(sha: Optional[str] = None) -> str:
version += ".torch" + torch.__version__[:5] version += ".torch" + torch.__version__[:5]
new_version_content = f""" new_version_content = f"""
import warnings import warnings
try: try:
import vllm.commit_id import vllm.commit_id
__commit__ = vllm.commit_id.__commit__ __commit__ = vllm.commit_id.__commit__
except Exception as e: except Exception as e:
warnings.warn(f"Failed to read commit hash:\\n + str(e)", warnings.warn(f"Failed to read commit hash:\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER" __commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.3.post1" __version__ = "0.5.4"
__dcu_version__ = f'0.5.3.post1+{version}' __dcu_version__ = f'0.5.4+{version}'
""" """
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
...@@ -507,15 +505,15 @@ def get_requirements() -> List[str]: ...@@ -507,15 +505,15 @@ def get_requirements() -> List[str]:
ext_modules = [] ext_modules = []
if _build_core_ext():
ext_modules.append(CMakeExtension(name="vllm._core_C"))
if _is_cuda() or _is_hip(): if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _build_custom_ops(): if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._C"))
if _install_punica():
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
package_data = { package_data = {
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
} }
...@@ -542,6 +540,7 @@ setup( ...@@ -542,6 +540,7 @@ setup(
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: Apache Software License", "License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
], ],
...@@ -553,7 +552,7 @@ setup( ...@@ -553,7 +552,7 @@ setup(
extras_require={ extras_require={
"tensorizer": ["tensorizer>=2.9.0"], "tensorizer": ["tensorizer>=2.9.0"],
}, },
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
package_data=package_data, package_data=package_data,
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
......
from vllm.utils import is_hip import pytest
from tests.quantization.utils import is_quant_method_supported
from ..utils import compare_two_settings from ..utils import compare_two_settings
...@@ -6,8 +8,37 @@ from ..utils import compare_two_settings ...@@ -6,8 +8,37 @@ from ..utils import compare_two_settings
def test_cpu_offload(): def test_cpu_offload():
compare_two_settings("meta-llama/Llama-2-7b-hf", [], compare_two_settings("meta-llama/Llama-2-7b-hf", [],
["--cpu-offload-gb", "4"]) ["--cpu-offload-gb", "4"])
if not is_hip():
# compressed-tensors quantization is currently not supported in ROCm.
compare_two_settings( @pytest.mark.skipif(not is_quant_method_supported("fp8"),
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [], reason="fp8 is not supported on this GPU type.")
["--cpu-offload-gb", "1"]) def test_cpu_offload_fp8():
# Test quantization of an unquantized checkpoint
compare_two_settings("meta-llama/Meta-Llama-3-8B-Instruct",
["--quantization", "fp8"],
["--quantization", "fp8", "--cpu-offload-gb", "2"])
# Test loading a quantized checkpoint
compare_two_settings("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", [],
["--cpu-offload-gb", "2"])
@pytest.mark.skipif(not is_quant_method_supported("awq"),
reason="awq is not supported on this GPU type.")
def test_cpu_offload_awq():
compare_two_settings("casperhansen/llama-3-8b-instruct-awq", [],
["--cpu-offload-gb", "2"])
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="gptq_marlin is not supported on this GPU type.")
def test_cpu_offload_compressed_tensors():
# Test wNa16
compare_two_settings("nm-testing/tinyllama-oneshot-w4a16-channel-v2", [],
["--cpu-offload-gb", "1"])
# Test w4a16_marlin24
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
[], ["--cpu-offload-gb", "1"])
# Test w8a8
compare_two_settings(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", [],
["--cpu-offload-gb", "1"])
...@@ -3,7 +3,7 @@ import gc ...@@ -3,7 +3,7 @@ import gc
import os import os
import sys import sys
from collections import UserList from collections import UserList
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
import pytest import pytest
import torch import torch
...@@ -11,7 +11,7 @@ import torch.nn as nn ...@@ -11,7 +11,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding) AutoTokenizer, BatchEncoding, BatchFeature)
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
...@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets: ...@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS return IMAGE_ASSETS
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding) _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)
class HfRunner: class HfRunner:
...@@ -152,7 +152,6 @@ class HfRunner: ...@@ -152,7 +152,6 @@ class HfRunner:
model_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False, is_embedding_model: bool = False,
is_vision_model: bool = False, is_vision_model: bool = False,
is_sparseml_model: bool = False,
) -> None: ) -> None:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
...@@ -169,9 +168,6 @@ class HfRunner: ...@@ -169,9 +168,6 @@ class HfRunner:
else: else:
if is_vision_model: if is_vision_model:
auto_cls = AutoModelForVision2Seq auto_cls = AutoModelForVision2Seq
elif is_sparseml_model:
from sparseml.transformers import SparseAutoModelForCausalLM
auto_cls = SparseAutoModelForCausalLM
else: else:
auto_cls = AutoModelForCausalLM auto_cls = AutoModelForCausalLM
...@@ -339,7 +335,6 @@ class HfRunner: ...@@ -339,7 +335,6 @@ class HfRunner:
processor_kwargs["images"] = images[i] processor_kwargs["images"] = images[i]
inputs = self.processor(**processor_kwargs) inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids
output = self.model.generate( output = self.model.generate(
**self.wrap_device(inputs), **self.wrap_device(inputs),
...@@ -381,7 +376,7 @@ class HfRunner: ...@@ -381,7 +376,7 @@ class HfRunner:
all_logprobs.append(seq_logprobs_lst) all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0] seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1] output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:] output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist()) all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids)) all_output_strs.append(self.tokenizer.decode(output_ids))
...@@ -513,11 +508,14 @@ class VllmRunner: ...@@ -513,11 +508,14 @@ class VllmRunner:
prompts: List[str], prompts: List[str],
max_tokens: int, max_tokens: int,
num_logprobs: int, num_logprobs: int,
images: Optional[List[Image.Image]] = None, images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0, greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens, max_tokens=max_tokens,
logprobs=num_logprobs) logprobs=num_logprobs,
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts, outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params, greedy_logprobs_params,
images=images) images=images)
......
...@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, ...@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case. # Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size # Note 16 = 128/block_size
"num_gpu_blocks_override": 2 * (16 + 1), "num_gpu_blocks_override": 2 * (16 + 2),
} }
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{ @pytest.mark.parametrize("baseline_llm_kwargs", [{
......
import time import time
from collections import deque from collections import deque
from typing import Deque, List, Set, Tuple from typing import List, Set, Tuple
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # noqa import pytest # noqa
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus from vllm.core.interfaces import AllocStatus
from vllm.core.policy import PolicyFactory
from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, SequenceGroup, SequenceStatus from vllm.sequence import Logprob, SequenceGroup, SequenceStatus
...@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len(): ...@@ -348,10 +347,10 @@ def test_prefill_schedule_max_prompt_len():
""" """
scheduler = initialize_scheduler(max_model_len=30) scheduler = initialize_scheduler(max_model_len=30)
_, seq_group = create_dummy_prompt("0", prompt_length=60) _, seq_group = create_dummy_prompt("0", prompt_length=60)
waiting = deque([seq_group]) scheduler.add_seq_group(seq_group)
budget = create_token_budget() budget = create_token_budget()
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 1 assert len(output.ignored_seq_groups) == 1
assert len(output.seq_groups) == 0 assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
...@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget(): ...@@ -364,15 +363,14 @@ def test_prefill_schedule_token_budget():
Test token budget respected. Test token budget respected.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=0) budget = create_token_budget(token_budget=0)
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
# 0 token budget == nothing is scheduled. # 0 token budget == nothing is scheduled.
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0 assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
...@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget(): ...@@ -381,8 +379,8 @@ def test_prefill_schedule_token_budget():
# 60 token budget == 1 request scheduled. # 60 token budget == 1 request scheduled.
budget = create_token_budget(token_budget=60) budget = create_token_budget(token_budget=60)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 1 assert len(output.seq_groups) == 1
assert budget.num_batched_tokens == 60 assert budget.num_batched_tokens == 60
...@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget(): ...@@ -391,14 +389,13 @@ def test_prefill_schedule_token_budget():
# Test when current_batched_tokens respected. # Test when current_batched_tokens respected.
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget(token_budget=60) budget = create_token_budget(token_budget=60)
add_token_budget(budget, 30, 0) add_token_budget(budget, 30, 0)
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
# Cannot schedule a prompt that doesn't fit the budget. # Cannot schedule a prompt that doesn't fit the budget.
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0 assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 30 assert budget.num_batched_tokens == 30
...@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget(): ...@@ -406,8 +403,8 @@ def test_prefill_schedule_token_budget():
assert len(remaining_waiting) == 1 assert len(remaining_waiting) == 1
budget = create_token_budget(token_budget=90) budget = create_token_budget(token_budget=90)
add_token_budget(budget, 30, 0) add_token_budget(budget, 30, 0)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.seq_groups) == 1 assert len(output.seq_groups) == 1
assert budget.num_batched_tokens == 90 assert budget.num_batched_tokens == 90
assert budget.num_curr_seqs == 1 assert budget.num_curr_seqs == 1
...@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs(): ...@@ -419,13 +416,12 @@ def test_prefill_schedule_max_seqs():
Test max seq respected. Test max seq respected.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(max_num_seqs=2) budget = create_token_budget(max_num_seqs=2)
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 2 assert len(output.seq_groups) == 2
assert budget.num_batched_tokens == 120 assert budget.num_batched_tokens == 120
...@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs(): ...@@ -433,13 +429,13 @@ def test_prefill_schedule_max_seqs():
assert len(remaining_waiting) == 1 assert len(remaining_waiting) == 1
# Verify curr_num_seqs respected. # Verify curr_num_seqs respected.
waiting = deque() scheduler.waiting = deque()
budget = create_token_budget(max_num_seqs=2) budget = create_token_budget(max_num_seqs=2)
add_token_budget(budget, 0, 2) add_token_budget(budget, 0, 2)
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0 assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
...@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora(): ...@@ -453,7 +449,6 @@ def test_prefill_schedule_max_lora():
""" """
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config) scheduler = initialize_scheduler(lora_config=lora_config)
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=120) budget = create_token_budget(token_budget=120)
curr_loras: Set[int] = set() curr_loras: Set[int] = set()
for i in range(2): for i in range(2):
...@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora(): ...@@ -463,7 +458,7 @@ def test_prefill_schedule_max_lora():
lora_name=str(i), lora_name=str(i),
lora_int_id=i + 1, lora_int_id=i + 1,
lora_path="abc")) lora_path="abc"))
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
# Add two more requests to verify lora is prioritized. # Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular # 0: Lora, 1: Lora, 2: regular, 3: regular
# In the first iteration, index 0, 2 is scheduled. # In the first iteration, index 0, 2 is scheduled.
...@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora(): ...@@ -471,10 +466,10 @@ def test_prefill_schedule_max_lora():
# prioritized. Verify that. # prioritized. Verify that.
for i in range(2, 4): for i in range(2, 4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
# Schedule 2 requests (0 and 2) # Schedule 2 requests (0 and 2)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, curr_loras)
waiting, budget, curr_loras) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 2 assert len(output.seq_groups) == 2
assert budget.num_batched_tokens == 120 assert budget.num_batched_tokens == 120
...@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora(): ...@@ -485,8 +480,8 @@ def test_prefill_schedule_max_lora():
# Reset curr_loras so that it can be scheduled. # Reset curr_loras so that it can be scheduled.
curr_loras = set() curr_loras = set()
budget = create_token_budget(token_budget=60) budget = create_token_budget(token_budget=60)
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, curr_loras)
remaining_waiting, budget, curr_loras) remaining_waiting = scheduler.waiting
assert len(output.seq_groups) == 1 assert len(output.seq_groups) == 1
assert output.seq_groups[0].seq_group.request_id == "1" assert output.seq_groups[0].seq_group.request_id == "1"
assert len(remaining_waiting) == 1 assert len(remaining_waiting) == 1
...@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity(): ...@@ -499,31 +494,29 @@ def test_prefill_schedule_no_block_manager_capacity():
Test that a sequence cannot be scheduled when the block manager has no capacity. Test that a sequence cannot be scheduled when the block manager has no capacity.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget() budget = create_token_budget()
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate = MagicMock()
scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
remainig_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 0 assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0 assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0 assert budget.num_curr_seqs == 0
assert len(remainig_waiting) == 3 assert len(remaining_waiting) == 3
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget() budget = create_token_budget()
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group) scheduler.add_seq_group(seq_group)
scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate = MagicMock()
scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
remaining_waiting, output = scheduler._schedule_prefills( output = scheduler._schedule_prefills(budget, None)
waiting, budget, None) remaining_waiting = scheduler.waiting
assert len(output.ignored_seq_groups) == 3 assert len(output.ignored_seq_groups) == 3
assert len(output.seq_groups) == 0 assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
...@@ -536,14 +529,12 @@ def test_decode_schedule_preempted(): ...@@ -536,14 +529,12 @@ def test_decode_schedule_preempted():
Test that decodes which cannot be scheduled are preempted. Test that decodes which cannot be scheduled are preempted.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
running.append(seq_group) scheduler._add_seq_group_to_running(seq_group)
scheduler.block_manager.can_append_slots = MagicMock() scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots): def cannot_append_second_group(seq_group, num_lookahead_slots):
...@@ -555,8 +546,8 @@ def test_decode_schedule_preempted(): ...@@ -555,8 +546,8 @@ def test_decode_schedule_preempted():
# 1 cannot be scheduled, and the lowest priority (request 2) # 1 cannot be scheduled, and the lowest priority (request 2)
# should be preempted. 1 will also be preempted. # should be preempted. 1 will also be preempted.
budget = create_token_budget() budget = create_token_budget()
remainig_running, output = scheduler._schedule_running( output = scheduler._schedule_running(budget, curr_loras)
running, budget, curr_loras, policy) remainig_running = scheduler.running
assert len(remainig_running) == 0 assert len(remainig_running) == 0
assert len(output.decode_seq_groups) == 1 assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0 assert len(output.prefill_seq_groups) == 0
...@@ -577,14 +568,12 @@ def test_decode_swap_beam_search(): ...@@ -577,14 +568,12 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks Test best_of > 1 swap out blocks
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
budget = create_token_budget() budget = create_token_budget()
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
running.append(seq_group) scheduler._add_seq_group_to_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
budget.add_num_seqs(seq_group.request_id, budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs()) seq_group.get_max_num_running_seqs())
...@@ -603,8 +592,8 @@ def test_decode_swap_beam_search(): ...@@ -603,8 +592,8 @@ def test_decode_swap_beam_search():
expected_swap_mapping = [("5", "7")] expected_swap_mapping = [("5", "7")]
scheduler.block_manager.swap_out.return_value = expected_swap_mapping scheduler.block_manager.swap_out.return_value = expected_swap_mapping
remainig_running, output = scheduler._schedule_running( output = scheduler._schedule_running(budget, curr_loras)
running, budget, curr_loras, policy) remainig_running = scheduler.running
assert len(remainig_running) == 0 assert len(remainig_running) == 0
assert len(output.decode_seq_groups) == 2 assert len(output.decode_seq_groups) == 2
assert len(output.prefill_seq_groups) == 0 assert len(output.prefill_seq_groups) == 0
...@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update(): ...@@ -628,20 +617,18 @@ def test_schedule_decode_blocks_to_copy_update():
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
running.append(seq_group) scheduler._add_seq_group_to_running(seq_group)
# The last request should be swapped out. # The last request should be swapped out.
scheduler.block_manager.append_slots = MagicMock() scheduler.block_manager.append_slots = MagicMock()
scheduler.block_manager.append_slots.return_value = [(2, 3)] scheduler.block_manager.append_slots.return_value = [(2, 3)]
budget = create_token_budget() budget = create_token_budget()
remaining_running, output = scheduler._schedule_running( output = scheduler._schedule_running(budget, curr_loras)
running, budget, curr_loras, policy) remaining_running = scheduler.running
assert len(remaining_running) == 0 assert len(remaining_running) == 0
assert len(output.decode_seq_groups) == 1 assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0 assert len(output.prefill_seq_groups) == 0
...@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update(): ...@@ -656,19 +643,17 @@ def test_schedule_decode_blocks_to_copy_update():
def test_schedule_swapped_simple(): def test_schedule_swapped_simple():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget() budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0 assert len(remaining_swapped) == 0
assert budget.num_batched_tokens == 1 assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2 assert budget.num_curr_seqs == 2
...@@ -683,8 +668,6 @@ def test_schedule_swapped_simple(): ...@@ -683,8 +668,6 @@ def test_schedule_swapped_simple():
def test_schedule_swapped_max_token_budget(): def test_schedule_swapped_max_token_budget():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2): for _ in range(2):
...@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget(): ...@@ -692,11 +675,11 @@ def test_schedule_swapped_max_token_budget():
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget(token_budget=1) budget = create_token_budget(token_budget=1)
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1 assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1 assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2 assert budget.num_curr_seqs == 2
...@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget(): ...@@ -706,8 +689,8 @@ def test_schedule_swapped_max_token_budget():
# Verify num_batched_tokens are respected. # Verify num_batched_tokens are respected.
budget = create_token_budget(token_budget=1) budget = create_token_budget(token_budget=1)
add_token_budget(budget, 1, 0) add_token_budget(budget, 1, 0)
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1 assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1 assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 0 assert budget.num_curr_seqs == 0
...@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget(): ...@@ -717,8 +700,6 @@ def test_schedule_swapped_max_token_budget():
def test_schedule_swapped_max_seqs(): def test_schedule_swapped_max_seqs():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(4): for i in range(4):
...@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs(): ...@@ -726,11 +707,11 @@ def test_schedule_swapped_max_seqs():
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget(max_num_seqs=2) budget = create_token_budget(max_num_seqs=2)
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2 assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 2 assert budget.num_batched_tokens == 2
assert budget.num_curr_seqs == 2 assert budget.num_curr_seqs == 2
...@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs(): ...@@ -738,8 +719,8 @@ def test_schedule_swapped_max_seqs():
assert len(output.prefill_seq_groups) == 0 assert len(output.prefill_seq_groups) == 0
# Verify num_curr_seqs are respected. # Verify num_curr_seqs are respected.
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
remaining_swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2 assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 2 assert budget.num_batched_tokens == 2
assert budget.num_curr_seqs == 2 assert budget.num_curr_seqs == 2
...@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs(): ...@@ -750,8 +731,6 @@ def test_schedule_swapped_max_seqs():
def test_schedule_swapped_max_loras(): def test_schedule_swapped_max_loras():
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config) scheduler = initialize_scheduler(lora_config=lora_config)
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras: Set[int] = set() curr_loras: Set[int] = set()
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(2): for i in range(2):
...@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras(): ...@@ -764,11 +743,11 @@ def test_schedule_swapped_max_loras():
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
budget = create_token_budget() budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 1 assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1 assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 1 assert budget.num_curr_seqs == 1
...@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras(): ...@@ -779,8 +758,6 @@ def test_schedule_swapped_max_loras():
def test_schedule_swapped_cannot_swap_in(): def test_schedule_swapped_cannot_swap_in():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2): for _ in range(2):
...@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in(): ...@@ -788,15 +765,15 @@ def test_schedule_swapped_cannot_swap_in():
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
# The last request should be swapped out. # The last request should be swapped out.
scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
# Since we cannot swap in, none of the requests are swapped in. # Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget() budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 2 assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0 assert budget.num_curr_seqs == 0
...@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in(): ...@@ -806,8 +783,6 @@ def test_schedule_swapped_cannot_swap_in():
def test_infeasible_swap(): def test_infeasible_swap():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2): for _ in range(2):
...@@ -815,15 +790,15 @@ def test_infeasible_swap(): ...@@ -815,15 +790,15 @@ def test_infeasible_swap():
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
# The last request should be swapped out. # The last request should be swapped out.
scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
# Since we cannot swap in, none of the requests are swapped in. # Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget() budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0 assert len(remaining_swapped) == 0
assert len(output.infeasible_seq_groups) == 2 assert len(output.infeasible_seq_groups) == 2
assert budget.num_batched_tokens == 0 assert budget.num_batched_tokens == 0
...@@ -834,23 +809,21 @@ def test_infeasible_swap(): ...@@ -834,23 +809,21 @@ def test_infeasible_swap():
def test_schedule_swapped_blocks_to_copy(): def test_schedule_swapped_blocks_to_copy():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) scheduler._add_seq_group_to_swapped(seq_group)
# The last request should be swapped out. # The last request should be swapped out.
scheduler.block_manager.append_slots = MagicMock() scheduler.block_manager.append_slots = MagicMock()
scheduler.block_manager.append_slots.return_value = [(2, 3)] scheduler.block_manager.append_slots.return_value = [(2, 3)]
budget = create_token_budget() budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped( output = scheduler._schedule_swapped(budget, curr_loras)
swapped, budget, curr_loras, policy) remaining_swapped = scheduler.swapped
assert len(remaining_swapped) == 0 assert len(remaining_swapped) == 0
assert len(output.decode_seq_groups) == 1 assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0 assert len(output.prefill_seq_groups) == 0
......
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. """Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run: Run:
```sh ```sh
cd $VLLM_PATH/tests cd $VLLM_PATH/tests
TEST_DIST_MODEL=facebook/opt-125m pytest \ pytest distributed/test_basic_distributed_correctness.py
distributed/test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
distributed/test_basic_distributed_correctness.py
``` ```
""" """
import os import os
...@@ -19,27 +14,48 @@ import pytest ...@@ -19,27 +14,48 @@ import pytest
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import fork_new_process_for_each_test
MODELS = [ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
@pytest.mark.skipif(cuda_device_count_stateless() < 2, @pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.") reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize(
@pytest.mark.parametrize("dtype", ["half"]) "model, distributed_executor_backend, attention_backend, test_suite", [
@pytest.mark.parametrize("max_tokens", [5]) ("facebook/opt-125m", "ray", "", "L4"),
("facebook/opt-125m", "mp", "", "L4"),
("meta-llama/Llama-2-7b-hf", "ray", "", "L4"),
("meta-llama/Llama-2-7b-hf", "mp", "", "L4"),
("facebook/opt-125m", "ray", "", "A100"),
("facebook/opt-125m", "mp", "", "A100"),
("facebook/opt-125m", "mp", "FLASHINFER", "A100"),
("meta-llama/Meta-Llama-3-8B", "ray", "FLASHINFER", "A100"),
])
@fork_new_process_for_each_test
def test_models( def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
model: str, model: str,
dtype: str, distributed_executor_backend: str,
max_tokens: int, attention_backend: str,
test_suite: str,
) -> None: ) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
if test_suite != TARGET_TEST_SUITE:
pytest.skip(f"Skip test for {test_suite}")
if model == "meta-llama/Llama-2-7b-hf" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
# test ray adag
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
if attention_backend:
os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend
dtype = "half"
max_tokens = 5
# NOTE: take care of the order. run vLLM first, and then run HF. # NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization. # vLLM needs a fresh new process without cuda initialization.
......
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. """Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run: Run:
```sh ```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \ pytest test_chunked_prefill_distributed.py
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
test_chunked_prefill_distributed.py
``` ```
""" """
import os
import pytest import pytest
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
from ..models.utils import check_outputs_equal from ..models.utils import check_outputs_equal
from ..utils import fork_new_process_for_each_test
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
@pytest.mark.skipif(cuda_device_count_stateless() < 2, @pytest.mark.skipif(cuda_device_count_stateless() < 2,
reason="Need at least 2 GPUs to run the test.") reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model, distributed_executor_backend", [
@pytest.mark.parametrize("dtype", ["half"]) ("facebook/opt-125m", "ray"),
@pytest.mark.parametrize("max_tokens", [5]) ("meta-llama/Llama-2-7b-hf", "ray"),
@pytest.mark.parametrize("chunked_prefill_token_size", [16]) ("facebook/opt-125m", "mp"),
("meta-llama/Llama-2-7b-hf", "mp"),
])
@fork_new_process_for_each_test
def test_models( def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
example_prompts, example_prompts,
model: str, model: str,
dtype: str, distributed_executor_backend: str,
max_tokens: int,
chunked_prefill_token_size: int,
) -> None: ) -> None:
distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
dtype = "half"
max_tokens = 5
chunked_prefill_token_size = 16
# Add a chunked prefill config. # Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256) max_num_seqs = min(chunked_prefill_token_size, 256)
......
"""Compare the outputs of HF and distributed vLLM when using greedy sampling. """Compare the outputs of HF and distributed vLLM when using greedy sampling.
The second test will hang if more than one test is run per command, so we need
to run the tests one by one. The solution is to pass arguments (model name) by
environment variables.
Run: Run:
```sh ```sh
TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf \ pytest -s -v test_multimodal_broadcast.py
test_multimodal_broadcast.py
TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct \
test_multimodal_broadcast.py
``` ```
""" """
import os
import pytest import pytest
from vllm.utils import cuda_device_count_stateless from vllm.utils import cuda_device_count_stateless
model = os.environ["TEST_DIST_MODEL"] from ..utils import fork_new_process_for_each_test
if model.startswith("llava-hf/llava"):
from ..models.test_llava import models, run_test @pytest.mark.skipif(cuda_device_count_stateless() < 2,
elif model.startswith("microsoft/Phi-3-vision"): reason="Need at least 2 GPUs to run the test.")
from ..models.test_phi3v import models, run_test @pytest.mark.parametrize("model, distributed_executor_backend", [
else: ("llava-hf/llava-1.5-7b-hf", "ray"),
raise NotImplementedError(f"Unsupported model: {model}") ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"),
("llava-hf/llava-1.5-7b-hf", "mp"),
("llava-hf/llava-v1.6-mistral-7b-hf", "mp"),
@pytest.mark.parametrize("tensor_parallel_size", [2]) ])
@pytest.mark.parametrize("dtype", ["half"]) @fork_new_process_for_each_test
@pytest.mark.parametrize("max_tokens", [128]) def test_models(hf_runner, vllm_runner, image_assets, model: str,
@pytest.mark.parametrize("num_logprobs", [5]) distributed_executor_backend: str) -> None:
def test_models(hf_runner, vllm_runner, image_assets,
tensor_parallel_size: int, dtype: str, max_tokens: int, dtype = "half"
num_logprobs: int) -> None: max_tokens = 5
if cuda_device_count_stateless() < tensor_parallel_size: num_logprobs = 5
pytest.skip( tensor_parallel_size = 2
f"Need at least {tensor_parallel_size} GPUs to run the test.")
if model.startswith("llava-hf/llava-1.5"):
distributed_executor_backend = os.getenv("DISTRIBUTED_EXECUTOR_BACKEND") from ..models.test_llava import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")
run_test( run_test(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
image_assets, image_assets,
model=models[0], model=models[0],
size_factors=[1.0], # So that LLaVA-NeXT processor may return nested list
size_factors=[0.25, 0.5, 1.0],
dtype=dtype, dtype=dtype,
max_tokens=max_tokens, max_tokens=max_tokens,
num_logprobs=num_logprobs, num_logprobs=num_logprobs,
......
from typing import Any, Dict
import pytest
import torch
from vllm.distributed.parallel_state import (_split_tensor_dict,
_update_nested_dict)
def test_split_tensor_dict():
    """Check how ``_split_tensor_dict`` splits a nested dict.

    The input mixes a string, tensors, a nested dict (holding two tensors
    and an int) and an empty dict. The test expects six metadata entries
    and expects the three tensors at positions 0-2 of the tensor list, in
    the order they appear in the input.
    """
    nested = {
        "key_1": torch.arange(5, dtype=torch.float32),
        "key_2": torch.tensor([], dtype=torch.float32),
        "key_3": 123,
    }
    test_dict = {
        "key_a": "a",
        "key_b": torch.arange(8, dtype=torch.float32),
        "key_c": nested,
        "key_d": {},
    }

    metadata_list, tensor_list = _split_tensor_dict(test_dict)

    assert len(metadata_list) == 6

    expected_tensors = (
        test_dict["key_b"],
        test_dict["key_c"]["key_1"],
        test_dict["key_c"]["key_2"],
    )
    # Index explicitly (rather than zip) so a too-short tensor list still
    # fails with an IndexError.
    for position, expected in enumerate(expected_tensors):
        assert torch.allclose(tensor_list[position], expected)
def test_split_tensor_dict_invalid_key():
    """Expect an ``AssertionError`` when a key contains the ``%`` character."""
    test_dict = {"a%b": "a"}

    with pytest.raises(AssertionError):
        _split_tensor_dict(test_dict)
def test_update_nested_dict():
    """Check that ``_update_nested_dict`` rebuilds nesting from flat keys.

    Each flat key uses ``%`` between levels; applying all pairs to an empty
    dict must produce the nested structure in ``expected``.
    """
    flattened_keys_values = [
        ("key1%key2%key3", "value1"),
        ("key1%key2%key4", "value2"),
        ("key1%key5", "value3"),
        ("key6%key7", "value4"),
        ("key8", "value5"),
    ]
    expected = {
        "key1": {
            "key2": {
                "key3": "value1",
                "key4": "value2",
            },
            "key5": "value3",
        },
        "key6": {
            "key7": "value4",
        },
        "key8": "value5",
    }

    rebuilt: Dict[str, Any] = {}
    for flat_key, value in flattened_keys_values:
        _update_nested_dict(rebuilt, flat_key, value)

    assert rebuilt == expected
"""
WARNING: This test runs in both single-node (4 GPUs) and multi-node
(2 nodes with 2 GPUs each) modes. If the test only uses 2 GPUs, it is
important to set the distributed backend to "mp" to avoid Ray scheduling
all workers in a node other than the head node, which can cause the test
to fail.
"""
import os import os
import pytest import pytest
from ..utils import compare_two_settings from ..utils import compare_two_settings, fork_new_process_for_each_test
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@pytest.mark.parametrize( @pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND", "MODEL_NAME, DIST_BACKEND"),
[ [
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
]) ])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND): DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp": if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
pytest.skip("Skipping multi-node pipeline parallel test for " pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend") "multiprocessing distributed backend")
USE_RAY_ADAG_NCCL = 0
USE_RAY_ADAG = 0
pp_args = [ pp_args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -59,5 +69,40 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, ...@@ -59,5 +69,40 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if EAGER_MODE: if EAGER_MODE:
pp_args.append("--enforce-eager") pp_args.append("--enforce-eager")
tp_args.append("--enforce-eager") tp_args.append("--enforce-eager")
pp_env = None
if USE_RAY_ADAG:
assert DIST_BACKEND == "ray", (
"Ray ADAG is only supported with Ray distributed backend")
pp_env = {
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
str(int(USE_RAY_ADAG_NCCL)),
}
compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)
@pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [
(2, "JackFram/llama-160m"),
])
@pytest.mark.parametrize("ATTN_BACKEND", [
"FLASH_ATTN",
"FLASHINFER",
])
@fork_new_process_for_each_test
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
cudagraph_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--pipeline-parallel-size",
str(PP_SIZE),
"--distributed-executor-backend",
"mp",
]
os.environ["VLLM_ATTENTION_BACKEND"] = ATTN_BACKEND
eager_args = cudagraph_args + ["--enforce-eager"]
compare_two_settings(MODEL_NAME, pp_args, tp_args) compare_two_settings(MODEL_NAME, eager_args, cudagraph_args)
import os
import pytest
from vllm.distributed.utils import get_pp_indices
def test_custom_layer_partition():
    """Check ``get_pp_indices`` under ``VLLM_PP_LAYER_PARTITION`` overrides.

    Valid partition strings must yield the expected ``(start, end)`` pair
    for every pipeline rank; malformed strings, a wrong number of
    partitions, or a layer count that does not match must raise
    ``ValueError``.
    """

    def _verify(partition_str, num_layers, pp_size, goldens):
        """Set the partition env var, compare every rank, then restore it."""
        bak = os.environ.get("VLLM_PP_LAYER_PARTITION", None)
        os.environ["VLLM_PP_LAYER_PARTITION"] = partition_str
        try:
            for pp_rank, golden in enumerate(goldens):
                assert get_pp_indices(num_layers, pp_rank, pp_size) == golden
        finally:
            # Always put the environment back, including when the check
            # raises (the pytest.raises cases below) and when the variable
            # was not set beforehand, so the override cannot leak into
            # later checks or other tests in the same process.
            if bak is None:
                os.environ.pop("VLLM_PP_LAYER_PARTITION", None)
            else:
                os.environ["VLLM_PP_LAYER_PARTITION"] = bak

    # Even partition
    _verify("5,5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    # Balanced partition
    _verify("4,6,6,4", 20, 4, [(0, 4), (4, 10), (10, 16), (16, 20)])
    # Put remainder somewhere
    _verify("5,6,5,6", 22, 4, [(0, 5), (5, 11), (11, 16), (16, 22)])
    # Invalid partition strings
    with pytest.raises(ValueError):
        _verify("5,5,5,5,", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    with pytest.raises(ValueError):
        _verify("5,5,5,a", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    # Wrong number of partitions
    with pytest.raises(ValueError):
        _verify("5,5,5", 20, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
    # Wrong number of layers
    with pytest.raises(ValueError):
        _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
import pytest import pytest
@pytest.fixture
def sample_prompts():
    """Provide four short text prompts."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    return prompts
@pytest.fixture
def sample_token_ids():
    """Provide four token-id lists of lengths one through four."""
    token_ids = [
        [0],
        [0, 1],
        [0, 2, 1],
        [0, 3, 1, 2],
    ]
    return token_ids
@pytest.fixture @pytest.fixture
def sample_regex(): def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
...@@ -66,4 +86,4 @@ column: "col_1" | "col_2" ...@@ -66,4 +86,4 @@ column: "col_1" | "col_2"
table: "table_1" | "table_2" table: "table_1" | "table_2"
condition: column "=" number condition: column "=" number
number: "1" | "2" number: "1" | "2"
""") """)
\ No newline at end of file
import json
import re
import weakref
import jsonschema
import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def llm():
    """Yield a weak proxy to a module-scoped ``LLM`` built from MODEL_NAME.

    The value is yielded inside ``deprecate_legacy_api()``; once the
    fixture is torn down the local reference is dropped and ``cleanup()``
    is called.
    """
    engine = LLM(model=MODEL_NAME, max_model_len=1024)

    with engine.deprecate_legacy_api():
        # pytest caches the fixture value, so hand out a weakref.proxy to
        # keep that cache from blocking garbage collection of the engine.
        yield weakref.proxy(engine)
        del engine

    cleanup()
@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    """Generate with a ``guided_regex`` option and require a full regex match.

    The same prompt is submitted twice; every returned text must fully
    match ``sample_regex``.
    """
    request_prompt = (
        f"Give an example IPv4 address with this regex: {sample_regex}")
    params = SamplingParams(temperature=0.8, top_p=0.95)

    outputs = llm.generate(
        prompts=[request_prompt, request_prompt],
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_regex": sample_regex},
    )

    assert outputs is not None
    for result in outputs:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        generated_text = result.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    """Generate with a ``guided_json`` option and validate against the schema.

    The same prompt is submitted twice; every returned text must parse as
    JSON and pass ``jsonschema.validate`` for ``sample_json_schema``.
    """
    request_prompt = (f"Give an example JSON for an employee profile "
                      f"that fits this schema: {sample_json_schema}")
    params = SamplingParams(temperature=1.0, max_tokens=1000)

    outputs = llm.generate(
        prompts=[request_prompt, request_prompt],
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_json": sample_json_schema},
    )

    assert outputs is not None
    for result in outputs:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        generated_text = result.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    """Generate with a ``guided_choice`` option and require a listed choice.

    Every returned text must be a member of ``sample_guided_choice``.
    """
    params = SamplingParams(temperature=0.8, top_p=0.95)

    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_choice": sample_guided_choice},
    )

    assert outputs is not None
    for result in outputs:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        generated_text = result.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    """Generate with a ``guided_grammar`` option and check the SQL output.

    Every returned text must parse with a Lark parser built from
    ``sample_sql_statements`` and, once stripped, equal the expected
    statement with its spaces removed.
    """
    params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1000)

    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=params,
        use_tqdm=True,
        guided_options_request={"guided_grammar": sample_sql_statements},
    )

    assert outputs is not None
    for result in outputs:
        assert result is not None
        assert isinstance(result, RequestOutput)
        prompt = result.prompt
        generated_text = result.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        grammar_parser = Lark(sample_sql_statements)
        grammar_parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        expected_sql = "SELECT col_1 from table_1 where col_1 = 1"
        ground_truth = expected_sql.replace(" ", "")
        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
...@@ -295,14 +295,19 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -295,14 +295,19 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
async for chunk in stream: async for chunk in stream:
assert chunk.usage is None assert chunk.usage is None
# Test stream=True, stream_options={"include_usage": True} # Test stream=True, stream_options={"include_usage": True,
stream = await client.chat.completions.create( # "continuous_usage_stats": False}}
model=model_name, stream = await client.chat.completions.create(model=model_name,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
temperature=0.0, temperature=0.0,
stream=True, stream=True,
stream_options={"include_usage": True}) stream_options={
"include_usage":
True,
"continuous_usage_stats":
False
})
async for chunk in stream: async for chunk in stream:
if chunk.choices[0].finish_reason is None: if chunk.choices[0].finish_reason is None:
...@@ -338,6 +343,25 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, ...@@ -338,6 +343,25 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
stream=False, stream=False,
stream_options={"include_usage": True}) stream_options={"include_usage": True})
# Test stream=True, stream_options={"include_usage": True,
# "continuous_usage_stats": True}
stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats": True
},
)
async for chunk in stream:
assert chunk.usage.prompt_tokens >= 0
assert chunk.usage.completion_tokens >= 0
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
chunk.usage.completion_tokens)
# NOTE: Not sure why, but when I place this after `test_guided_regex_chat` # NOTE: Not sure why, but when I place this after `test_guided_regex_chat`
# (i.e. using the same ordering as in the Completions API tests), the test # (i.e. using the same ordering as in the Completions API tests), the test
......
...@@ -55,8 +55,9 @@ def zephyr_pa_files(): ...@@ -55,8 +55,9 @@ def zephyr_pa_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files): def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files,
args = [ zephyr_pa_files):
return [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
"bfloat16", "bfloat16",
...@@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files): ...@@ -85,7 +86,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files, zephyr_pa_files):
"128", "128",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@pytest.fixture(scope="module")
def server(default_server_args):
    """Yield a ``RemoteOpenAIServer`` for ``MODEL_NAME``, once per module.

    Enters ``RemoteOpenAIServer(MODEL_NAME, default_server_args)`` as a
    context manager and yields the resulting object; the context is exited
    when the module-scoped fixture is torn down.
    """
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as srv:
        yield srv
...@@ -537,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI): ...@@ -537,6 +541,28 @@ async def test_logits_bias(client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text assert first_response != completion.choices[0].text
@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
    """Check that ``allowed_token_ids`` limits the completion to those ids.

    Requests a single token with ``allowed_token_ids`` passed through
    ``extra_body`` and asserts that the one returned token converts back to
    an id from the allowed list.
    """
    tok = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    permitted_ids = [21555, 21557, 21558]
    request_kwargs = dict(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=1,
        temperature=0.0,
        seed=42,
        extra_body={"allowed_token_ids": permitted_ids},
        logprobs=1,
    )
    completion = await client.completions.create(**request_kwargs)

    returned_tokens = completion.choices[0].logprobs.tokens
    assert len(returned_tokens) == 1

    returned_ids = tok.convert_tokens_to_ids(returned_tokens)
    assert returned_ids[0] in permitted_ids
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"]) ["outlines", "lm-format-enforcer"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment