Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import uuid
from collections import defaultdict
from typing import Optional
from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorMetadata) KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from .utils import create_request, create_scheduler, create_vllm_config from .utils import create_request, create_scheduler, create_vllm_config
def test_basic_inferface(): def test_basic_interface():
"""Unit test for basic NixlConnector interface functionality.""" """Unit test for basic NixlConnector interface functionality."""
vllm_config = create_vllm_config() vllm_config = create_vllm_config()
...@@ -25,7 +35,7 @@ def test_basic_inferface(): ...@@ -25,7 +35,7 @@ def test_basic_inferface():
scheduler.add_request(request) scheduler.add_request(request)
# Remote Prefill, triggers NixlConnectorMetdata. # Remote Prefill, triggers NixlConnectorMetadata.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
...@@ -72,3 +82,292 @@ def test_prompt_less_than_block_size(): ...@@ -72,3 +82,292 @@ def test_prompt_less_than_block_size():
# This request should be scheduled regularly. # This request should be scheduled regularly.
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
class FakeNixlWrapper:
    """In-memory stand-in for NixlWrapper used by the connector tests.

    It deliberately does not subclass nixl._api.nixl_agent, because nixl
    may not be installed in the test environment.
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        # Number of "PROC" answers a handle gets before it reports "DONE".
        self._cycles_before_xfer_done = 0
        # Per-handle count of check_xfer_state() polls answered with "PROC".
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(int)

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        """Return one random descriptor string per entry of caches_data."""
        return [str(uuid.uuid4()) for _entry in caches_data]

    def register_memory(self, descs) -> None:
        """No-op: nothing is registered in the fake."""
        return None

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        """Return one random descriptor string per entry of blocks_data."""
        return [str(uuid.uuid4()) for _entry in blocks_data]

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        """Return a random integer acting as the prepared-list handle."""
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        """Return the fixed fake agent metadata blob."""
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        """Return the fixed remote agent name, ignoring the metadata."""
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        """Return no notifications.

        Used to collect done_sending, which we don't test yet.
        """
        return {}

    def check_xfer_state(self, handle: int) -> str:
        """Report "PROC" for the configured number of polls, then "DONE"."""
        polls_so_far = self._check_xfer_state_cycles[handle]
        if polls_so_far < self._cycles_before_xfer_done:
            self._check_xfer_state_cycles[handle] = polls_so_far + 1
            return "PROC"
        return "DONE"

    def release_xfer_handle(self, handle: int) -> None:
        """No-op: the fake holds no resources per handle."""
        return None

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        """No-op: notifications are dropped."""
        return None

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        """Return a random integer acting as the transfer handle."""
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        """Always report the transfer as in progress."""
        return "PROC"

    ############################################################
    # The following are for changing the behavior during testing.
    ############################################################

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles
class FakeNixlConnectorWorker(NixlConnectorWorker):
    """NixlConnectorWorker whose handshake is simulated in-process.

    The handshake sleeps for a configurable latency and registers a fake
    remote agent instead of talking to a peer over zmq.
    """

    REMOTE_ENGINE_ID = "remote_engine"

    def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs):
        super().__init__(*args, **kwargs)
        # Seconds that every simulated handshake blocks for.
        self._hand_shake_latency = hand_shake_latency

    def _nixl_handshake(self, host: str, port: int,
                        remote_tp_size: int) -> dict[int, str]:
        """Simulate a slow handshake and return {0: remote agent name}."""
        # Mimic slow _nixl_handshake, as well as bypass zmq communication.
        time.sleep(self._hand_shake_latency)

        # These should've been done in register_kv_caches(), called by
        # gpu_model_runner. Here we just hardcode some dummy values.
        self.slot_size_bytes = 4096
        self.block_len = self.slot_size_bytes * self.block_size
        self.num_blocks = 1
        self.dst_num_blocks[self.engine_id] = self.num_blocks

        fake_remote_metadata = NixlAgentMetadata(
            engine_id=self.REMOTE_ENGINE_ID,
            agent_metadata=FakeNixlWrapper.AGENT_METADATA,
            kv_caches_base_addr=[0],
            num_blocks=1,
            block_len=self.block_len,
            attn_backend_name=self.backend_name,
        )
        agent_name = self.add_remote_agent(fake_remote_metadata,
                                           remote_tp_size=remote_tp_size)
        return {0: agent_name}
class TestNixlHandshake:
    """Tests of NixlConnector's KV-load path against the fake NIXL classes.

    Every test patches NixlWrapper with FakeNixlWrapper and swaps the
    connector's worker for FakeNixlConnectorWorker, so neither nixl nor a
    remote peer is required.
    """

    @staticmethod
    def _remote_kv_transfer_params(remote_block_ids: list[int],
                                   remote_tp_size: int = 1) -> dict:
        """Build kv_transfer_params pointing at the fake remote engine."""
        return {
            "remote_block_ids": remote_block_ids,
            "remote_engine_id": FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
            "remote_host": "localhost",
            "remote_port": 1234,
            "remote_tp_size": remote_tp_size,
        }

    @staticmethod
    def _start_load_kv_nonblocking(connector) -> None:
        """Call start_load_kv and assert that it returns within 0.1s."""
        # Mimic maybe_setup_kv_connector in gpu_model_runner.
        dummy_ctx = ForwardContext(
            no_compile_layers={},
            attn_metadata={},
            virtual_engine=0,
        )
        _before_load = time.perf_counter()
        connector.start_load_kv(dummy_ctx)
        _after_load = time.perf_counter()
        assert _after_load - _before_load < 0.1, "start_load_kv took " \
            f"{_after_load - _before_load} seconds"

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_multi_xfer_one_engine(
            self,
            # dist_init is a fixture that initializes the distributed
            # environment.
            dist_init):
        """Test case where multiple xfers are initiated to the same engine.

        This test triggers the connector to load remote KV for the same
        `request_id`. The transfer is not done immediately due to
        `set_cycles_before_xfer_done`, so there is a state where there are
        multiple transfer states for the same `request_id`, and `get_finished`
        should handle it correctly (wait for all transfers to be done).

        Raises:
            TimeoutError: if the request is not reported as received before
                the deadline, instead of polling forever.
        """
        vllm_config = create_vllm_config()

        request_id = "req_id"

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id, hand_shake_latency=0)
        assert isinstance(connector.connector_worker.nixl_wrapper,
                          FakeNixlWrapper)
        connector.connector_worker.nixl_wrapper.set_cycles_before_xfer_done(3)
        num_xfers = 4
        # Bound the polling loop so a regression fails the test instead of
        # hanging the whole run; matches the sibling tests' timeout style.
        timeout = 10.0
        start = time.perf_counter()
        while time.perf_counter() - start < timeout:
            # For the same request_id, initiate multiple xfers across different
            # round of `execute_model` calls.
            metadata = NixlConnectorMetadata()
            if num_xfers > 0:
                num_xfers -= 1
                metadata.add_new_req(
                    request_id=request_id,
                    local_block_ids=[
                        num_xfers + 1, num_xfers + 2, num_xfers + 3
                    ],
                    kv_transfer_params=self._remote_kv_transfer_params(
                        remote_block_ids=[
                            num_xfers + 4, num_xfers + 5, num_xfers + 6
                        ]))
            connector.bind_connector_metadata(metadata)

            self._start_load_kv_nonblocking(connector)

            # Mimic get_finished_kv_transfers in gpu_model_runner.
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                assert request_id in done_recving
                return

            connector.clear_connector_metadata()
        raise TimeoutError("Took too long to complete all transfers.")

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
        (1, 1),
        (2, 1),
        (4, 2),
        (4, 4),
    ])
    def test_async_load_kv(
            self,
            # Fixture that initializes the distributed environment.
            dist_init,
            # Simulate consumer-producer TP sizes.
            decode_tp_size,
            prefill_tp_size):
        """Test that NixlConnector's start_load_kv should be non-blocking."""

        vllm_config = create_vllm_config()
        vllm_config.parallel_config.tensor_parallel_size = decode_tp_size

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id)
        metadata = NixlConnectorMetadata()
        metadata.add_new_req(
            request_id="id",
            local_block_ids=[1, 2, 3],
            kv_transfer_params=self._remote_kv_transfer_params(
                remote_block_ids=[4, 5, 6], remote_tp_size=prefill_tp_size))
        connector.bind_connector_metadata(metadata)

        timeout = 2.5
        start = time.perf_counter()
        while time.perf_counter() - start < timeout:
            self._start_load_kv_nonblocking(connector)
            time.sleep(0.5)  # backoff for the async handshake to complete.
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                return
        raise TimeoutError("Took too long to complete async handshake.")

    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    def test_concurrent_load_kv(
            self,
            # dist_init is a fixture that initializes the distributed
            # environment.
            dist_init):
        """Test that multiple start_load_kv calls should occur concurrently."""

        vllm_config = create_vllm_config()

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        connector.connector_worker = FakeNixlConnectorWorker(
            vllm_config, connector.engine_id)
        metadata = NixlConnectorMetadata()
        total_reqs = 5
        for i in range(total_reqs):
            metadata.add_new_req(
                request_id=f"id_{i}",
                local_block_ids=[1, 2, 3],
                kv_transfer_params=self._remote_kv_transfer_params(
                    remote_block_ids=[4, 5, 6]))
        connector.bind_connector_metadata(metadata)

        timeout = 2.5 * total_reqs
        cnt_finished_reqs = 0
        start = time.perf_counter()
        while time.perf_counter() - start < timeout:
            self._start_load_kv_nonblocking(connector)
            time.sleep(0.5)  # backoff for the async handshake to complete.
            connector.bind_connector_metadata(NixlConnectorMetadata())
            _, done_recving = connector.get_finished(finished_req_ids=set())
            if len(done_recving) > 0:
                cnt_finished_reqs += len(done_recving)
                if cnt_finished_reqs == total_reqs:
                    return
        raise TimeoutError("Took too long to complete async handshake.")
...@@ -66,7 +66,7 @@ def test_basic_lifecycle(): ...@@ -66,7 +66,7 @@ def test_basic_lifecycle():
assert len(scheduler_output.finished_req_ids) == 1 assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model() # (2b): execute_model()
...@@ -81,7 +81,7 @@ def test_basic_lifecycle(): ...@@ -81,7 +81,7 @@ def test_basic_lifecycle():
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0 assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model() # (3b): execute_model()
......
...@@ -36,7 +36,7 @@ def test_basic_lifecycle(): ...@@ -36,7 +36,7 @@ def test_basic_lifecycle():
# Nothing running and empty scheduler output. # Nothing running and empty scheduler output.
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0 assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0 assert scheduler_output.total_num_scheduled_tokens == 0
...@@ -158,7 +158,7 @@ def test_interleaved_lifecycle(): ...@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
[request_local_a, request_local_b]) [request_local_a, request_local_b])
...@@ -169,7 +169,7 @@ def test_interleaved_lifecycle(): ...@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b]) reqs=[request_local_a, request_local_b])
...@@ -177,14 +177,14 @@ def test_interleaved_lifecycle(): ...@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive. # STEP 4: KVs arrive.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
[request_local_a, request_local_b], [request_local_a, request_local_b],
...@@ -196,7 +196,7 @@ def test_interleaved_lifecycle(): ...@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output( model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote]) [request_local_a, request_local_b, request_remote])
......
...@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler): ...@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0 assert len(scheduler.finished_recving_kv_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager. # EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0 assert len(scheduler.encoder_cache_manager.freed) == 0
...@@ -150,6 +149,7 @@ def create_request( ...@@ -150,6 +149,7 @@ def create_request(
request_id=f"id-{request_id}", request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=None, multi_modal_inputs=None,
multi_modal_placeholders=None, multi_modal_placeholders=None,
multi_modal_hashes=None, multi_modal_hashes=None,
...@@ -183,6 +183,7 @@ def create_model_runner_output( ...@@ -183,6 +183,7 @@ def create_model_runner_output(
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
) )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from collections.abc import Callable
from typing import NamedTuple, Optional, Union
import numpy as np
import pytest
import torch
from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits,
create_penalty_tensor,
create_prompt_tokens_tensor,
fake_apply_logitsprocs,
fake_update_logitsprocs_state)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available
# yapf: disable
from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder,
LogitBiasLogitsProcessor,
LogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
MoveDirectionality,
init_builtin_logitsprocs)
# yapf: enable
from vllm.v1.sample.metadata import SamplingMetadata
# Result of is_pin_memory_available(), evaluated once at import time and
# forwarded to init_builtin_logitsprocs().
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
# The built-in logits processors are initialized with MAX_NUM_REQS + 1
# as their max_num_reqs.
MAX_NUM_REQS = 256
# Vocabulary size of the fake logits and of randomly drawn token ids.
VOCAB_SIZE = 1024
# Number of random output tokens generated per request in the fake
# sampling metadata.
NUM_OUTPUT_TOKENS = 20
# One device string when the platform reports exactly one device,
# otherwise two ("<device_type>:0" and "<device_type>:1").
CUDA_DEVICES = [
f"{current_platform.device_type}:{i}"
for i in range(1 if current_platform.device_count() == 1 else 2)
]
# Exclusive upper bound on the random prompt length (lengths start at 1).
MAX_NUM_PROMPT_TOKENS = 64
# Value used for `min_tokens` in min-tokens requests, and the threshold the
# validator compares output length against.
MIN_TOKENS_LEN_THRESHOLD = 5
# NOTE(review): presumably the number of requests generated per logitproc
# type in a test batch; its use is not visible in this part of the file.
REQS_PER_LOGITPROC = 50
# Identifier for "no logits processor enabled" on a request.
STR_NO_LOGITPROC = "none"
# LogitsProcessor subclass or "none"
LogitprocType = Union[type[LogitsProcessor], str]
class LogitsProcsRequestParams:
    """Encapsulates key params for a single request in a batch.

    Params can be customized based on the enabled logitproc.
    """

    workload_index: int
    # Logitproc enabled: a LogitsProcessor subclass, or the "none" str id
    logitproc_type: LogitprocType
    out_tokens: list[int]  # Output tokens required for min tokens test
    params: SamplingParams  # Settings customized for logitproc

    def __init__(self, workload_index: int, logitproc_type: LogitprocType):
        """Create params for one request.

        Args:
            workload_index: index of this request within the workload.
            logitproc_type: logitproc enabled for this request.
        """
        self.workload_index = workload_index
        self.logitproc_type = logitproc_type
        # Number of output tokens is randomly 0x, 1x or 2x the min-tokens
        # threshold which will be used in testing. Output token values
        # don't matter *for these tests* so use 0 as a dummy value
        self.out_tokens = ([0] *
                           (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)))
        self.params = _sampling_params_from_logitproc(logitproc_type)

    def __str__(self):
        """For debugging: `<ClassName>(attr=value, ...)`."""
        summ = ', '.join(f'{k}={v}' for k, v in vars(self).items())
        # Use the real class name rather than a hard-coded placeholder so
        # validation error messages identify the object correctly.
        return f"{type(self).__name__}({summ})"
def _generate_fake_sampling_metadata(
    num_output_tokens: int,
    batch_size: int,
    vocab_size: int,
    device: torch.device,
) -> SamplingMetadata:
    """Build greedy SamplingMetadata with random tokens and builtin
    logitsprocs.

    Args:
        num_output_tokens: number of random output tokens per request.
        batch_size: number of requests in the batch.
        vocab_size: exclusive upper bound for random token ids.
        device: device passed to the tensor helpers and the logitsprocs.

    Returns:
        SamplingMetadata with neutral penalties and freshly initialized
        builtin logits processors.
    """
    output_token_ids: list[list[int]] = []
    prompt_token_ids: list[list[int]] = []
    for _ in range(batch_size):
        output_token_ids.append(
            np.random.randint(0, vocab_size, size=num_output_tokens).tolist())
        # Prompt length is drawn before the prompt tokens themselves.
        prompt_len = np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
        prompt_token_ids.append(
            np.random.randint(0, vocab_size, size=prompt_len).tolist())

    logitsprocs = init_builtin_logitsprocs(
        pin_memory_available=PIN_MEMORY_AVAILABLE,
        max_num_reqs=MAX_NUM_REQS + 1,
        device=device)

    return SamplingMetadata(
        temperature=torch.full((batch_size, ), 0.0),
        all_greedy=True,
        all_random=False,
        top_p=None,
        top_k=None,
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids,
                                                     vocab_size, device),
        output_token_ids=output_token_ids,
        frequency_penalties=create_penalty_tensor(batch_size, 0.0, device),
        presence_penalties=create_penalty_tensor(batch_size, 0.0, device),
        repetition_penalties=create_penalty_tensor(batch_size, 1.0, device),
        no_penalties=True,
        allowed_token_ids_mask=None,
        bad_words_token_ids={},
        logitsprocs=logitsprocs)
def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes:
    """Create the fake logits and sampling metadata shared by a test.

    Every row gets a single dominant token (index 0) so that the min-p
    test has a well-defined expected outcome.
    """
    logits = create_fake_logits(batch_size, VOCAB_SIZE)
    for row in range(batch_size):
        logits[row, 0] = 10.0  # High logit for first token
        logits[row, 1:] = 1e-2  # Others remain low
    metadata = _generate_fake_sampling_metadata(NUM_OUTPUT_TOKENS, batch_size,
                                                VOCAB_SIZE,
                                                torch.device(device))
    return LogitsprocsTestFakes(logits=logits, sampling_metadata=metadata)
def _sampling_params_from_logitproc(
        logitproc_type: LogitprocType) -> SamplingParams:
    """Return SamplingParams tailored to the given logitproc.

    Starts from settings that enable no logitproc and lets the logitproc's
    `gen_request_fxn` (when one is registered) modify them in place.
    """
    # Baseline: request with no logitproc enabled
    request_kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0}
    customize = logitsprocs_test_mapping[logitproc_type].gen_request_fxn
    if customize:
        customize(request_kwargs)
    return SamplingParams(**request_kwargs)
def _generate_mixed_logitsprocs_batch_params(
    reqs_per_logitproc: int,
    logitsprocs_types: list[str],
) -> list[LogitsProcsRequestParams]:
    """Build per-request params for a batch mixing several logitsprocs.

    Each entry of `logitsprocs_types` is assigned to `reqs_per_logitproc`
    requests, so the batch holds
    `reqs_per_logitproc * len(logitsprocs_types)` requests. The assignment
    of logitsprocs to batch positions is randomly shuffled.

    Args:
      reqs_per_logitproc: number of requests using each logitproc
      logitsprocs_types: logitsprocs under test

    Returns:
      List of per-request params which configure the engine for that
      request's enabled logitproc
    """
    batch_size = len(logitsprocs_types) * reqs_per_logitproc
    # A random permutation of the batch positions; integer-dividing a
    # permuted position by `reqs_per_logitproc` selects the logitproc, which
    # shuffles logitsprocs across the batch.
    shuffled_slots = random.sample(range(batch_size), k=batch_size)
    batch_params: list[LogitsProcsRequestParams] = []
    for workload_idx, slot in enumerate(shuffled_slots):
        chosen_type = logitsprocs_types[slot // reqs_per_logitproc]
        batch_params.append(
            LogitsProcsRequestParams(workload_index=workload_idx,
                                     logitproc_type=chosen_type))
    return batch_params
def _raise_error_invalid(
    msg_suffix: str,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
    err_cls: type[Exception] = ValueError,
) -> None:
    """Raise `err_cls` describing a validation failure for one request.

    The message carries the step, batch index, workload index and request
    params, followed by `msg_suffix` as the reason.
    """
    message = (f"Validation failed for step={step_idx}, "
               f"batch_index={batch_index}, "
               f"workload_index={request_params.workload_index}, "
               f"req_params={request_params}. Reason: {msg_suffix}")
    raise err_cls(message)
def _logit_bias_params(kwargs: dict) -> None:
    """Set `logit_bias` in `kwargs` to one random token with a random bias."""
    biased_token = random.randint(0, VOCAB_SIZE - 1)
    bias = random.choice([-0.1, 0.2])
    kwargs["logit_bias"] = {biased_token: bias}
def _logit_bias_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check that logit bias was applied to exactly the biased tokens.

    Biased tokens must equal the original logit plus the bias; all other
    tokens must be unchanged (both up to `pytest.approx` tolerance).
    """
    logit_bias = request_params.params.logit_bias
    workload_row = persistent_batch[batch_index].workload_index
    row_before = test_fakes.logits[workload_row].cpu()
    row_after = logits_new[batch_index].cpu()

    for token_id in range(VOCAB_SIZE):
        logit_old_value = row_before[token_id]
        logit_new_value = row_after[token_id]

        if token_id in logit_bias:
            bias_value = logit_bias[token_id]
            exp_value = bias_value + logit_old_value
            if logit_new_value != pytest.approx(exp_value):
                _raise_error_invalid(msg_suffix=(
                    f"Biased token {token_id} logit value {logit_new_value} "
                    f"does not match expected value {exp_value} "
                    f"given bias {bias_value}"),
                                     batch_index=batch_index,
                                     request_params=request_params,
                                     step_idx=step_idx)
        elif logit_new_value != pytest.approx(logit_old_value):
            _raise_error_invalid(msg_suffix=(
                f"Unbiased token {token_id} logit value {logit_new_value} "
                f"does not match expected value {logit_old_value}"),
                                 batch_index=batch_index,
                                 request_params=request_params,
                                 step_idx=step_idx)
def _min_p_params(kwargs: dict) -> None:
"""Min-p logitproc config"""
kwargs["min_p"] = 0.1
def _min_p_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check the min-p masking pattern of one request's logits.

    Token 0 (the dominant token) must never be masked. Every other token
    must be masked (-inf) when the request's min_p > 0, and unmasked when
    min_p is 0.
    """
    masked_value = -float("inf")
    for token_id in range(VOCAB_SIZE):
        logits_for_token = logits_new[batch_index][token_id]

        if token_id == 0:
            # Dominant token should always be unmasked
            if logits_for_token == masked_value:
                _raise_error_invalid(
                    msg_suffix="Invalid: dominant token 0 masked (-inf)",
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx)
            continue

        if request_params.params.min_p > 0.0:
            # Non-dominant tokens should be masked when min_p > 0
            if logits_for_token != masked_value:
                _raise_error_invalid(
                    msg_suffix=
                    f"Invalid: non-dominant token {token_id} not masked",
                    batch_index=batch_index,
                    request_params=request_params,
                    step_idx=step_idx)
        elif logits_for_token == masked_value:
            # No masking when min_p is 0
            _raise_error_invalid(
                msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0",
                batch_index=batch_index,
                request_params=request_params,
                step_idx=step_idx)
def _min_tokens_params(kwargs: dict) -> None:
    """Enable min-tokens on a request.

    Sets `min_tokens` to the test threshold and `stop_token_ids` to a
    random-length list of random token ids.
    """
    kwargs["min_tokens"] = MIN_TOKENS_LEN_THRESHOLD
    # The list length is drawn first, then each stop token id.
    num_stop_tokens = np.random.randint(0, VOCAB_SIZE)
    stop_token_ids = []
    for _ in range(num_stop_tokens):
        stop_token_ids.append(np.random.randint(0, VOCAB_SIZE - 1))
    kwargs["stop_token_ids"] = stop_token_ids
def _min_tokens_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check min-tokens processor state and the resulting logits.

    First the MinTokensLogitsProcessor's per-index state is compared with
    the request's reference data; then every token's logit is checked:
    stop tokens must be masked (-inf) while the min length is not reached,
    and nothing else may be masked.
    """

    def _fail(reason: str, err_cls: type[Exception] = ValueError) -> None:
        # All failures share the same request context.
        _raise_error_invalid(msg_suffix=reason,
                             batch_index=batch_index,
                             request_params=request_params,
                             step_idx=step_idx,
                             err_cls=err_cls)

    ref_num_out_tokens = len(request_params.out_tokens)
    min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD
    ref_all_stop_token_ids = request_params.params.all_stop_token_ids
    mt_lp: MinTokensLogitsProcessor = next(
        test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor))
    assert isinstance(mt_lp, MinTokensLogitsProcessor)
    tracked_entry = mt_lp.min_toks.get(batch_index, None)

    # Validate min-token logits processor state
    if tracked_entry:
        (_, tracked_out_tokens, all_stop_token_ids) = tracked_entry
        num_out_tokens = len(tracked_out_tokens)
        if num_out_tokens != ref_num_out_tokens:
            _fail("Number of output tokens in min-token logit processor "
                  f"request metadata ({num_out_tokens}) does not match "
                  f"reference ({ref_num_out_tokens}).")
        if ref_all_stop_token_ids != all_stop_token_ids:
            _fail("Stop token ids do not match reference; all_stop_token_ids: "
                  f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: "
                  f"{sorted(ref_all_stop_token_ids)}")
        if min_reached:
            _fail(
                "Expected min-tokens request with min reached, but batch "
                "index is recognized by min-tokens logits processor.",
                err_cls=RuntimeError)
    elif not min_reached:
        _fail(
            "Expected min-tokens request with min not reached, but batch "
            "index is not recognized by min-tokens logits processor.",
            err_cls=RuntimeError)

    # Validate min-token logits
    masked_value = -float("inf")
    for token_id in range(VOCAB_SIZE):
        logits_for_token = logits_new[batch_index][token_id]
        expect_masked = (token_id in ref_all_stop_token_ids
                         and not min_reached)
        if expect_masked:
            if logits_for_token != masked_value:
                _fail(f"Token {token_id} is a stop token and "
                      "the sequence has not reached min length, "
                      "but the token is not masked "
                      f"(logit={logits_for_token})")
        elif logits_for_token == masked_value:
            _fail(f"Token {token_id} should not be masked but "
                  f"is (output len={ref_num_out_tokens})")
def _none_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Check that the request's logits were left untouched.

    Compares the request's row of `logits_new` with its original row in
    `test_fakes.logits` and reports every differing token on mismatch.
    """
    workload_row = persistent_batch[batch_index].workload_index
    logits = test_fakes.logits[workload_row].cpu()
    ref_logits = logits_new[batch_index]
    if torch.all(ref_logits == logits):
        return

    differs = ref_logits != logits
    mismatch_toks = differs.nonzero(as_tuple=True)[0].tolist()
    mismatch_strs = []
    for token in mismatch_toks:
        # `val` is the original logit, `ref_val` the one under test; the
        # names are part of the message through the `=` format specifier.
        val = float(logits[token])
        ref_val = float(ref_logits[token])
        mismatch_strs.append(f"({token=},{val=},{ref_val=})")
    _raise_error_invalid(msg_suffix=(
        f"Unexpected modification of logits: {','.join(mismatch_strs)}"),
                         batch_index=batch_index,
                         request_params=request_params,
                         step_idx=step_idx)
class LogitsprocTestHelpers(NamedTuple):
    """Pair of callables used to set up and validate one logitproc test."""

    # Validation function called to check the logits of a request.
    eval_fxn: Callable
    # Optional function that customizes the request's SamplingParams
    # kwargs in place; None when no customization is needed.
    gen_request_fxn: Optional[Callable] = None
# Registry of test helpers, keyed by logitproc type (a LogitsProcessor
# subclass, or STR_NO_LOGITPROC for requests with no logitproc enabled).
# Each entry supplies the validation function and, where the request needs
# customized sampling params, the function that generates them.
logitsprocs_test_mapping = {
STR_NO_LOGITPROC:
LogitsprocTestHelpers(eval_fxn=_none_validate),
LogitBiasLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params,
eval_fxn=_logit_bias_validate),
MinPLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_p_params,
eval_fxn=_min_p_validate),
MinTokensLogitsProcessor:
LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params,
eval_fxn=_min_tokens_validate),
}
def _get_test_cases() -> list[list[str]]:
    """Return the logitsprocs combinations to test.

    The cases are, in order: no logitproc alone; each logitproc paired with
    "no logitproc"; and finally every registered type together.
    """
    all_types = list(logitsprocs_test_mapping.keys())
    cases = [[STR_NO_LOGITPROC]]
    cases.extend([logitproc_type, STR_NO_LOGITPROC]
                 for logitproc_type in all_types
                 if logitproc_type != STR_NO_LOGITPROC)
    cases.append(all_types)
    return cases
def _generate_fake_step_update(
    persistent_batch: list[LogitsProcsRequestParams],
    workload_params: list[LogitsProcsRequestParams],
    wdx: int,
    batch_update_builder: BatchUpdateBuilder,
) -> tuple[Optional[BatchUpdate], int, int]:
    """Simulate one step of persistent batch changes.

    Randomly removes requests from `persistent_batch`, adds requests taken
    from `workload_params` starting at offset `wdx`, condenses the batch so
    that no gaps remain, and finally applies a random set of pairwise swaps.
    `persistent_batch` is mutated in place and every operation is recorded
    in `batch_update_builder`.

    Args:
        persistent_batch: Current fake persistent batch (modified in place)
        workload_params: Full list of requests in the workload
        wdx: Index of the next workload request that has not been added yet
        batch_update_builder: Accumulates removed/added/moved records

    Returns:
        Tuple of (the result of `batch_update_builder.get_and_reset()` for
        the condensed batch size, the updated workload offset, the number
        of workload requests still to be added)
    """
    batch_size = len(persistent_batch)
    workload_size = len(workload_params)
    workload_reqs_remaining = workload_size - wdx
    max_add_remove_per_step = max(1, int(0.2 * workload_size))

    # 50% of steps: add no reqs
    # Other 50%: add a limited number of reqs (less than the number
    # of workload reqs remaining, less than an arbitrary max)
    # If no workload reqs remain: 100% of steps have 0 adds
    # Note: both list elements are evaluated before random.choice() picks
    # one, so random.randint() is called even when 0 ends up being chosen.
    num_step_add = random.choice([
        0,
        random.randint(1, min(max_add_remove_per_step,
                              workload_reqs_remaining))
    ]) if workload_reqs_remaining else 0

    # 50% of steps: remove no requests
    # Other 50%: remove a limited number of reqs (less than the number
    # persistent batch reqs remaining, less than an arbitrary max)
    # If persistent batch is empty: 100% of steps have 0 removals until
    # more requests are added. Assume that removed requests are always
    # drawn from the current batch, before new adds
    num_step_remove = random.choice([
        0, random.randint(1, min(max_add_remove_per_step, batch_size))
    ]) if batch_size else 0

    # Number of added requests that can reuse a slot freed by a removal
    num_step_add_replace = min(num_step_add, num_step_remove)

    # Generate fake removed request indices drawn from persistent batch indices
    for removal in random.sample(range(batch_size), num_step_remove):
        batch_update_builder.removed_append(removal)

    # Get added requests from workload
    for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]:
        # Replace as many removed requests as possible with added requests
        add_remove_idx = batch_update_builder.pop_removed()
        batch_update_builder.added.append(
            (add_remove_idx, add_req_params.params, add_req_params.out_tokens))
        persistent_batch[add_remove_idx] = add_req_params

    # Append remaining added requests to end of batch
    add_reqs_append = workload_params[(wdx +
                                       num_step_add_replace):(wdx +
                                                              num_step_add)]
    batch_update_builder.added.extend([
        (adx + batch_size, add_req_params.params, add_req_params.out_tokens)
        for adx, add_req_params in enumerate(add_reqs_append)
    ])
    persistent_batch.extend(add_reqs_append)
    pre_condense_batch_size = len(persistent_batch)
    wdx += num_step_add  # Update workload offset

    # Simulate condensing persistent batch
    last_nonempty_index = pre_condense_batch_size - 1
    # Slots that were already filled by a move during this condense pass
    condensed_to_idxs = set()
    while batch_update_builder.removed:
        if (last_nonempty_index in batch_update_builder.removed
                or last_nonempty_index in condensed_to_idxs):
            last_nonempty_index -= 1
            continue
        # last_nonempty_index is the highest persistent batch index that was
        # not removed
        first_empty_index = batch_update_builder.peek_removed()
        assert first_empty_index is not None
        if first_empty_index > last_nonempty_index:
            break
        # first_empty_index is the lowest removed persistent batch index
        # that is less than last_nonempty_index
        #
        # move last_nonempty_index -> first_empty_index
        batch_update_builder.pop_removed()
        condensed_to_idxs.add(first_empty_index)
        persistent_batch[first_empty_index] = persistent_batch[
            last_nonempty_index]
        batch_update_builder.moved.append(
            (last_nonempty_index, first_empty_index,
             MoveDirectionality.UNIDIRECTIONAL))

        last_nonempty_index -= 1

    # Now removed requests & gaps left by non-removed requests that got
    # moved downward are grouped consecutively in the upper indices of
    # the persistent batch. Truncate them to get condensed persistent batch
    condensed_batch_size = batch_size + num_step_add - num_step_remove
    # Slice assignment keeps the caller's list object and shrinks it in place
    persistent_batch[:] = persistent_batch[0:condensed_batch_size]

    if condensed_batch_size > 1:
        # Simulate arbitrary reorder_batch() in the kernel backend
        # Generate a random number k of non-overlapping swap tuples
        k = random.randint(0, condensed_batch_size // 2)
        idxs = list(range(condensed_batch_size))
        random.shuffle(idxs)
        # Consecutive pairs of the shuffled indices never share an index
        swaps = [
            tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)
        ]
        batch_update_builder.moved.extend([
            (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps
        ])
        for adx, bdx in swaps:
            persistent_batch[adx], persistent_batch[bdx] = persistent_batch[
                bdx], persistent_batch[adx]

    return (batch_update_builder.get_and_reset(condensed_batch_size), wdx,
            workload_size - wdx)
def _assert_valid(
    batch_size: int,
    persistent_batch: list[LogitsProcsRequestParams],
    test_fakes: LogitsprocsTestFakes,
    slice_idxs: list[int],
    logits_w_lp: torch.Tensor,
    step_idx: int,
) -> None:
    """Validate the logits processor output for the whole persistent batch.

    For a non-empty batch, each request is checked with the validation
    function registered for its logitproc type. For an empty batch, the
    logits processor output must be empty too, otherwise ValueError is
    raised.
    """
    if slice_idxs:
        for batch_index in range(batch_size):
            request_params = persistent_batch[batch_index]
            # Pick the validator that matches this request's logitproc
            helpers = logitsprocs_test_mapping[request_params.logitproc_type]
            validate = helpers.eval_fxn
            validate(test_fakes=test_fakes,
                     persistent_batch=persistent_batch,
                     logits_new=logits_w_lp,
                     batch_index=batch_index,
                     request_params=request_params,
                     step_idx=step_idx)
        return

    # No requests were sliced: the persistent batch must be empty as well
    assert len(persistent_batch) == 0
    if logits_w_lp.shape[0] != 0:
        raise ValueError("Fake persistent batch is empty but logitsprocs "
                         f"output batch has shape {logits_w_lp.shape}")
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC])
@pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases())
def test_logitsprocs(device: str, reqs_per_logitproc: int,
                     logitsprocs_under_test: list[str]):
    """Drive the logits processors through a randomized persistent batch.

    Requests from a mixed workload are added to and removed from a fake
    persistent batch step by step; after every step the logits processors
    are updated, applied, and their output is validated per request.
    """
    random.seed(40)
    torch.set_default_device(device)

    # Shuffled workload in which each request uses one of the logitsprocs
    # under test, or none at all
    workload_params = _generate_mixed_logitsprocs_batch_params(
        reqs_per_logitproc=reqs_per_logitproc,
        logitsprocs_types=logitsprocs_under_test)
    workload_size = len(workload_params)

    # Fake data structures shared by all steps
    test_fakes = _generate_test_fakes(workload_size, device)

    next_workload_idx = 0  # Next request index in workload to add
    # Persistent batch state: the request params currently in the batch
    persistent_batch: list[LogitsProcsRequestParams] = []
    # Collects the removed/added/moved records of each step
    batch_update_builder = BatchUpdateBuilder()

    reqs_left = workload_size
    batch_size = 0
    step_idx = 0
    # Stop once the whole workload has been added and the batch has drained
    while reqs_left or batch_size:
        batch_update, next_workload_idx, reqs_left = (
            _generate_fake_step_update(
                persistent_batch=persistent_batch,
                workload_params=workload_params,
                wdx=next_workload_idx,
                batch_update_builder=batch_update_builder,
            ))
        batch_size = len(persistent_batch)

        # Apply fake batch update to logitsprocs
        fake_update_logitsprocs_state(test_fakes, batch_update)

        # Emulate application of logits processors in engine
        slice_idxs = [req.workload_index for req in persistent_batch]
        logits_w_lp = fake_apply_logitsprocs(test_fakes, slice_idxs).cpu()

        _assert_valid(
            batch_size=batch_size,
            persistent_batch=persistent_batch,
            test_fakes=test_fakes,
            slice_idxs=slice_idxs,
            logits_w_lp=logits_w_lp,
            step_idx=step_idx,
        )

        step_idx += 1
...@@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62 ...@@ -13,9 +13,10 @@ EXPECTED_VALUE = 0.62
# FIXME(rob): enable prefix caching once supported. # FIXME(rob): enable prefix caching once supported.
MODEL = "meta-llama/Llama-3.2-1B-Instruct" MODEL = "meta-llama/Llama-3.2-1B-Instruct"
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501 MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False,gpu_memory_utilization=0.8" # noqa: E501
SERVER_ARGS = [ SERVER_ARGS = [
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests" "--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests",
"--gpu-memory-utilization=0.8"
] ]
NUM_CONCURRENT = 100 NUM_CONCURRENT = 100
...@@ -32,7 +33,7 @@ def test_prompt_logprobs_e2e(): ...@@ -32,7 +33,7 @@ def test_prompt_logprobs_e2e():
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
def test_promt_logprobs_e2e_server(): def test_prompt_logprobs_e2e_server():
with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server: with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server:
url = f"{remote_server.url_for('v1')}/completions" url = f"{remote_server.url_for('v1')}/completions"
......
...@@ -6,12 +6,14 @@ import pytest ...@@ -6,12 +6,14 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID, from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler) RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = "cuda" DEVICE = current_platform.device_type
@pytest.fixture @pytest.fixture
...@@ -21,7 +23,7 @@ def rejection_sampler(): ...@@ -21,7 +23,7 @@ def rejection_sampler():
def create_logits_tensor(output_token_ids: list[list[int]], def create_logits_tensor(output_token_ids: list[list[int]],
vocab_size: int = 100) -> torch.Tensor: vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that """Helper function to create logits tensor that
will produce desired token ids on argmax""" will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids] token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids) num_total_tokens = sum(len(tokens) for tokens in token_ids)
...@@ -41,8 +43,8 @@ def create_sampling_metadata( ...@@ -41,8 +43,8 @@ def create_sampling_metadata(
top_p: Optional[torch.Tensor] = None, top_p: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None, generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata: ) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set """Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling to the given value. Either all greedy or all random sampling
is used. is used.
""" """
generators = generators or {} generators = generators or {}
...@@ -57,7 +59,6 @@ def create_sampling_metadata( ...@@ -57,7 +59,6 @@ def create_sampling_metadata(
all_random=not all_greedy, all_random=not all_greedy,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
min_p=torch.empty(1, ),
generators=generators, generators=generators,
max_num_logprobs=0, max_num_logprobs=0,
no_penalties=False, no_penalties=False,
...@@ -66,10 +67,9 @@ def create_sampling_metadata( ...@@ -66,10 +67,9 @@ def create_sampling_metadata(
presence_penalties=torch.tensor([]), presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]), repetition_penalties=torch.tensor([]),
output_token_ids=[], output_token_ids=[],
min_tokens={},
logit_bias=[None],
allowed_token_ids_mask=None, allowed_token_ids_mask=None,
bad_words_token_ids={}, bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
) )
......
...@@ -8,10 +8,13 @@ import pytest ...@@ -8,10 +8,13 @@ import pytest
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import make_tensor_with_pad from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler from vllm.v1.sample.sampler import Sampler
PIN_MEMORY_AVAILABLE = is_pin_memory_available()
MAX_NUM_REQS = 256
VOCAB_SIZE = 1024 VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20 NUM_OUTPUT_TOKENS = 20
CUDA_DEVICES = [ CUDA_DEVICES = [
...@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor( ...@@ -48,18 +51,6 @@ def _create_prompt_tokens_tensor(
) )
def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> list[Optional[dict[int, float]]]:
res: list[Optional[dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res
def _create_allowed_token_ids( def _create_allowed_token_ids(
batch_size: int, batch_size: int,
vocab_size: int, vocab_size: int,
...@@ -145,7 +136,6 @@ def _create_default_sampling_metadata( ...@@ -145,7 +136,6 @@ def _create_default_sampling_metadata(
all_random=False, all_random=False,
top_p=None, top_p=None,
top_k=None, top_k=None,
min_p=None,
generators={}, generators={},
max_num_logprobs=0, max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids, prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
...@@ -155,43 +145,13 @@ def _create_default_sampling_metadata( ...@@ -155,43 +145,13 @@ def _create_default_sampling_metadata(
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device), presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device), repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
no_penalties=True, no_penalties=True,
min_tokens={},
logit_bias=[None] * batch_size,
allowed_token_ids_mask=None, allowed_token_ids_mask=None,
bad_words_token_ids={}, bad_words_token_ids={},
logitsprocs=LogitsProcessorManager(),
) )
return fake_sampling_metadata return fake_sampling_metadata
def _generate_min_token_penalties_and_stop_tokens(
num_output_tokens: int, batch_size: int, vocab_size: int,
batch_indices_for_min_token_penalty: list[int]
) -> dict[int, tuple[int, set[int]]]:
"""
Generates and returns a dict of minimum token penalties and
corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
min_tokens: dict[int, tuple[int, set[int]]] = {}
for index in range(batch_size):
if index in batch_indices_for_min_token_penalty:
min_tokens[index] = (
np.random.randint(num_output_tokens + 1,
2 * num_output_tokens),
set(
np.random.randint(0, vocab_size - 1)
for _ in range(np.random.randint(0, vocab_size))))
else:
min_tokens[index] = (np.random.randint(0,
num_output_tokens), set())
return min_tokens
def _create_weighted_output_token_list( def _create_weighted_output_token_list(
batch_size: int, batch_size: int,
vocab_size: int) -> tuple[list[list[int]], list[list[int]]]: vocab_size: int) -> tuple[list[list[int]], list[list[int]]]:
...@@ -227,36 +187,6 @@ def _create_weighted_output_token_list( ...@@ -227,36 +187,6 @@ def _create_weighted_output_token_list(
return output_token_ids, sorted_token_ids_in_output return output_token_ids, sorted_token_ids_in_output
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
    """
    Tests that the logits of a request's stop token ids are set to -inf
    while the number of output tokens is below SamplingParams.min_tokens,
    and that no other logits are masked.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))

    # Pick a random subset of batch indices that receive the penalty
    num_penalized = np.random.randint(0, batch_size)
    penalized_indices = np.random.randint(0,
                                          batch_size - 1,
                                          size=num_penalized).tolist()
    min_tokens = _generate_min_token_penalties_and_stop_tokens(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, penalized_indices)
    sampling_metadata.min_tokens = min_tokens

    logits = Sampler().apply_penalties(fake_logits, sampling_metadata).cpu()

    neg_inf = -float("inf")
    for batch_idx in range(batch_size):
        _, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
        logits_for_req = logits[batch_idx]
        for token_id in range(VOCAB_SIZE):
            if token_id in stop_token_ids:
                assert logits_for_req[token_id] == neg_inf
            else:
                assert logits_for_req[token_id] != neg_inf
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("presence_penalty", [-2.0, 2.0]) @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
...@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, ...@@ -401,80 +331,6 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
or non_penalized_token_id in output_tokens) or non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("min_p", [0.0, 0.1])
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
    """
    Tests that applying min_p masks tokens whose probability is below
    min_p * max_prob with -inf, and masks nothing when min_p is 0.
    """
    torch.set_default_device(device)
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    # Give every request exactly one dominant token (token 0)
    for row in range(batch_size):
        fake_logits[row, 0] = 10.0  # High logit for first token
        fake_logits[row, 1:] = 1e-2  # Others remain low
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Same min_p value for every request in the batch
    sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)

    logits = Sampler().apply_min_p(fake_logits, sampling_metadata.min_p).cpu()

    neg_inf = -float("inf")
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        for token_id in range(VOCAB_SIZE):
            # Only non-dominant tokens may be masked, and only if min_p > 0
            if token_id != 0 and min_p > 0.0:
                assert logits_for_req[token_id] == neg_inf
            else:
                assert logits_for_req[token_id] != neg_inf
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
    """
    Test to verify that logit bias is applied per request: the logit of
    the single biased token of each request is shifted by `bias_value`,
    while the logits of all other tokens are left unchanged.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Request i biases token min(i, VOCAB_SIZE - 1) by bias_value
    sampling_metadata.logit_bias = _create_logit_bias(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        bias_value=bias_value,
    )
    sampler = Sampler()
    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        # Must mirror the token choice made by _create_logit_bias() above
        biased_index = min(batch_idx, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            if biased_index == token_id:
                # 1e-2 is the logit value the expectations here assume
                # for every token before the bias is added
                assert logits_for_req[token_id] == pytest.approx(bias_value +
                                                                 1e-2)
            else:
                assert logits_for_req[token_id] == pytest.approx(1e-2)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2]) @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
......
...@@ -2,25 +2,26 @@ ...@@ -2,25 +2,26 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch import torch
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
from torch import Generator from torch import Generator
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p, from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
is_flashinfer_available) is_flashinfer_available)
DEVICE = "cuda" DEVICE = current_platform.device_type
BATCH_SIZE = 1024 BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024 VOCAB_SIZE = 128 * 1024
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
if is_flashinfer_available:
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_default_device(): def reset_default_device():
""" """
Explicitly set the default device, which can affect subsequent tests. Explicitly set the default device, which can affect subsequent tests.
Adding this fixture helps avoid this problem. Adding this fixture helps avoid this problem.
""" """
original_device = torch.get_default_device() original_device = torch.get_default_device()
...@@ -28,7 +29,7 @@ def reset_default_device(): ...@@ -28,7 +29,7 @@ def reset_default_device():
torch.set_default_device(original_device) torch.set_default_device(original_device)
def test_topk_impl_equivalance(): def test_topk_impl_equivalence():
torch.set_default_device(DEVICE) torch.set_default_device(DEVICE)
generator = Generator(device=DEVICE).manual_seed(33) generator = Generator(device=DEVICE).manual_seed(33)
...@@ -58,8 +59,8 @@ def test_flashinfer_sampler(): ...@@ -58,8 +59,8 @@ def test_flashinfer_sampler():
This test verifies that the FlashInfer top-k and top-p sampling This test verifies that the FlashInfer top-k and top-p sampling
implementation produces the same results as the Python implementation. implementation produces the same results as the Python implementation.
NOTE: FlashInfer did not directly expose an interface for fused top-k and NOTE: FlashInfer did not directly expose an interface for fused top-k and
top-p prob renorm (it did provide fused sampling but we cannot compare top-p prob renorm (it did provide fused sampling but we cannot compare
sampling results due to randomness), so we will compare the probability sampling results due to randomness), so we will compare the probability
renormed consequently by top-k and then top-p of FlashInfer implementation. renormed consequently by top-k and then top-p of FlashInfer implementation.
''' '''
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from enum import Enum from enum import Enum
from typing import Optional from typing import NamedTuple, Optional
import regex as re import regex as re
import torch
from vllm import CompletionOutput from vllm import CompletionOutput
from vllm.utils import make_tensor_with_pad
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor
from vllm.v1.sample.metadata import SamplingMetadata
class BatchLogprobsComposition(Enum): class BatchLogprobsComposition(Enum):
...@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob( ...@@ -134,3 +139,77 @@ def compute_correct_cumulative_logprob(
logprobs = completion_output.logprobs logprobs = completion_output.logprobs
assert logprobs is not None assert logprobs is not None
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)]) return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])
def create_fake_logits(batch_size: int, vocab_size: int) -> torch.Tensor:
    """Return a (batch_size, vocab_size) float tensor filled with 1e-2."""
    return torch.full((batch_size, vocab_size), 1e-2, dtype=torch.float)
def create_penalty_tensor(batch_size: int, penalty_value: float,
                          device: torch.device) -> torch.Tensor:
    """Return a 1-D float tensor with one `penalty_value` per request,
    allocated on `device`."""
    shape = (batch_size, )
    return torch.full(shape,
                      penalty_value,
                      device=device,
                      dtype=torch.float)
def create_prompt_tokens_tensor(
    prompt_token_ids: list[list[int]],
    vocab_size: int,
    device: torch.device,
) -> torch.Tensor:
    """Pack per-request prompt token ids into a single int64 tensor.

    The packing is delegated to `make_tensor_with_pad`, with `vocab_size`
    passed as the pad value and pinned memory disabled.
    """
    padded_prompts = make_tensor_with_pad(
        prompt_token_ids,
        pad=vocab_size,
        dtype=torch.int64,
        device=device,
        pin_memory=False,
    )
    return padded_prompts
class LogitsprocsTestFakes(NamedTuple):
    """Bundle of the fake data structures used by logitsprocs tests."""
    logits: torch.Tensor
    sampling_metadata: SamplingMetadata

    def get_logitsprocs_by_cls(
        self,
        cls: type[LogitsProcessor],
    ) -> Iterator[LogitsProcessor]:
        """Iterate over the logits processors that are instances of `cls`.

        Args:
          cls: :class:`LogitsProcessor` subclass to filter by

        Returns:
          Iterator over the matching logits processors
        """
        all_logitsprocs = self.sampling_metadata.logitsprocs.all
        return (logitproc for logitproc in all_logitsprocs
                if isinstance(logitproc, cls))

    def get_logitsprocs(self) -> Iterator[LogitsProcessor]:
        """Return every logits processor held by the sampling metadata."""
        logitsprocs = self.sampling_metadata.logitsprocs
        return logitsprocs.all
def fake_update_logitsprocs_state(
    test_fakes: LogitsprocsTestFakes,
    batch_update: BatchUpdate,
) -> None:
    """Mimic the engine core's persistent batch state update by passing
    `batch_update` to every logits processor in `test_fakes`."""
    for processor in test_fakes.get_logitsprocs():
        processor.update_state(batch_update)
def fake_apply_logitsprocs(
    test_fakes: LogitsprocsTestFakes,
    slice_indices: list[int],
) -> torch.Tensor:
    """Mimic the engine core applying logits processors.

    The rows of the fake logits selected by `slice_indices` are copied,
    then run through every logits processor in turn; the result of the
    last one is returned.
    """
    row_idxs = torch.tensor(slice_indices, dtype=torch.long)
    # Work on a copy so that the fake logits themselves stay untouched
    batch_logits = test_fakes.logits[row_idxs].clone()
    for logitproc in test_fakes.get_logitsprocs():
        batch_logits = logitproc.apply(batch_logits)
    return batch_logits
...@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ...@@ -10,6 +10,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig) VllmConfig)
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.eagle import EagleProposer
model_dir = "meta-llama/Llama-3.1-8B-Instruct" model_dir = "meta-llama/Llama-3.1-8B-Instruct"
...@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer: ...@@ -38,15 +39,17 @@ def _create_proposer(method: str, k: int) -> EagleProposer:
num_speculative_tokens=k, num_speculative_tokens=k,
) )
vllm_config = VllmConfig(model_config=model_config, vllm_config = VllmConfig(
cache_config=CacheConfig(), model_config=model_config,
speculative_config=speculative_config, cache_config=CacheConfig(),
device_config=DeviceConfig(device="cuda"), speculative_config=speculative_config,
parallel_config=ParallelConfig(), device_config=DeviceConfig(device=current_platform.device_type),
load_config=LoadConfig(), parallel_config=ParallelConfig(),
scheduler_config=SchedulerConfig()) load_config=LoadConfig(),
scheduler_config=SchedulerConfig())
return EagleProposer(vllm_config=vllm_config, device='cuda') return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
def test_prepare_inputs(): def test_prepare_inputs():
...@@ -59,7 +62,7 @@ def test_prepare_inputs(): ...@@ -59,7 +62,7 @@ def test_prepare_inputs():
a, a + 1, ..., a + b - n2 - 1, a, a + 1, ..., a + b - n2 - 1,
a + b, a + b + 1, ..., a + b + c - n3 - 1] a + b, a + b + 1, ..., a + b + c - n3 - 1]
""" """
device = torch.device('cuda') device = torch.device(current_platform.device_type)
# a = 4, b = 7, c = 5 # a = 4, b = 7, c = 5
# n1 = 1, n2 = 3, n3 = 2 # n1 = 1, n2 = 3, n3 = 2
...@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, ...@@ -198,7 +201,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens): def test_propose(num_speculative_tokens):
# Use GPU device # Use GPU device
device = torch.device('cuda') device = torch.device(current_platform.device_type)
# Setup test parameters # Setup test parameters
batch_size = 2 batch_size = 2
......
...@@ -4,24 +4,30 @@ ...@@ -4,24 +4,30 @@
import asyncio import asyncio
import os import os
from contextlib import ExitStack from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional from typing import Optional
import pytest import pytest
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
DP_SIZE = int(os.getenv("DP_SIZE", 2))
engine_args = AsyncEngineArgs( engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b", model="ibm-research/PowerMoE-3b",
enforce_eager=True, enforce_eager=True,
disable_log_requests=True, disable_log_requests=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=int(os.getenv("DP_SIZE", 2)), data_parallel_size=DP_SIZE,
) )
if not current_platform.supports_v1(engine_args.create_model_config()): if not current_platform.supports_v1(engine_args.create_model_config()):
...@@ -74,12 +80,32 @@ async def generate( ...@@ -74,12 +80,32 @@ async def generate(
async def test_load(output_kind: RequestOutputKind, async def test_load(output_kind: RequestOutputKind,
data_parallel_backend: str): data_parallel_backend: str):
stats_loggers = {}
@dataclass
class SimpleStatsLogger(StatLoggerBase):
init_count: int = 0
finished_req_count: int = 0
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
stats_loggers[engine_index] = self
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
if iteration_stats:
self.finished_req_count += len(
iteration_stats.finished_requests)
def log_engine_initialized(self):
self.init_count += 1
with ExitStack() as after: with ExitStack() as after:
prompt = "This is a test of data parallel" prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend engine_args.data_parallel_backend = data_parallel_backend
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args,
stat_loggers=[SimpleStatsLogger])
after.callback(engine.shutdown) after.callback(engine.shutdown)
NUM_REQUESTS = 100 NUM_REQUESTS = 100
...@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
for request_id in request_ids: for request_id in request_ids:
tasks.append( tasks.append(
asyncio.create_task( asyncio.create_task(
generate(engine, generate(engine, request_id, prompt, output_kind,
request_id, NUM_EXPECTED_TOKENS)))
prompt, # Short sleep to ensure that requests are distributed.
output_kind, await asyncio.sleep(0.01)
NUM_EXPECTED_TOKENS,
data_parallel_rank=0)))
# Confirm that we got all the EXPECTED tokens from the requests. # Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks, done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION) return_when=asyncio.FIRST_EXCEPTION)
...@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind, ...@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
assert not core_client.engines_running assert not core_client.engines_running
assert not core_client.reqs_in_flight assert not core_client.reqs_in_flight
# Check that requests were distributed between the engines
print(f"Stats loggers after test: {stats_loggers}")
assert len(stats_loggers) == DP_SIZE
assert stats_loggers[0].init_count == 1
for sl in stats_loggers.values():
slogger: SimpleStatsLogger = sl
assert slogger.finished_req_count > NUM_REQUESTS // (
DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.platforms import Platform
MODEL_NAME = "ibm-research/PowerMoE-3b"
# Number of data parallel ranks for external LB testing
DP_SIZE = int(os.getenv("DP_SIZE", "2"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
class ExternalLBServerManager:
    """Manage one vLLM server process per data parallel rank.

    Intended for use as a context manager in external load balancer tests:
    entering launches ``dp_size`` servers in parallel (rank ``r`` listens on
    port ``8000 + r``) and exiting stops every server that was started.
    """

    def __init__(self,
                 model_name: str,
                 dp_size: int,
                 api_server_count: int,
                 base_server_args: list,
                 tp_size: int = TP_SIZE):
        """Record the launch configuration; no server is started here.

        Args:
            model_name: Model passed to every ``RemoteOpenAIServer``.
            dp_size: Number of data parallel ranks (one server per rank).
            api_server_count: Value forwarded as ``--api-server-count``.
            base_server_args: CLI arguments shared by all ranks; copied per
                rank and never mutated.
            tp_size: Tensor parallel size, i.e. devices given to each rank.
        """
        self.model_name = model_name
        self.dp_size = dp_size
        self.tp_size = tp_size
        self.api_server_count = api_server_count
        self.base_server_args = base_server_args
        # Filled by the launcher threads in completion order, so the index
        # in this list is not guaranteed to equal the rank.
        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
        self.server_threads: list[threading.Thread] = []

    def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
        """Start all server instances for external LB mode.

        Returns:
            The started ``(server, server_args)`` pairs.

        Raises:
            RuntimeError: If fewer than ``dp_size`` servers came up. Servers
                that did start are stopped before the error is raised.
        """
        for rank in range(self.dp_size):
            # Create server args for this specific rank
            server_args = self.base_server_args.copy()

            # Add external LB specific arguments
            server_args.extend([
                "--data-parallel-size",
                str(self.dp_size),
                "--data-parallel-rank",
                str(rank),
                "--data-parallel-size-local",
                "1",
                "--tensor-parallel-size",
                str(self.tp_size),
                "--port",
                str(8000 + rank),  # Different port for each rank
                "--api-server-count",
                str(self.api_server_count),
            ])

            # Use a thread to start each server to allow parallel initialization
            def start_server(r: int, sargs: list[str]):
                try:
                    # Each rank owns a contiguous slice of tp_size devices.
                    # Use the instance's tp_size (not the module default) so
                    # the visible devices match --tensor-parallel-size.
                    first_device = r * self.tp_size
                    visible_devices = ",".join(
                        str(Platform.device_id_to_physical_device_id(i))
                        for i in range(first_device,
                                       first_device + self.tp_size))
                    server = RemoteOpenAIServer(
                        self.model_name,
                        sargs,
                        auto_port=False,
                        env_dict={"CUDA_VISIBLE_DEVICES": visible_devices})
                    server.__enter__()
                    print(f"Server rank {r} started successfully with "
                          f"{self.api_server_count} API servers")
                    self.servers.append((server, sargs))
                except Exception as e:
                    print(f"Failed to start server rank {r}: {e}")
                    raise

            thread = threading.Thread(target=start_server,
                                      args=(rank, server_args))
            thread.start()

            self.server_threads.append(thread)

        # Wait for all servers to start
        for thread in self.server_threads:
            thread.join()

        if len(self.servers) != self.dp_size:
            # __exit__ is not invoked when __enter__ raises, so stop the
            # servers that did come up instead of leaking their processes.
            self.__exit__(None, None, None)
            raise RuntimeError("Servers failed to start")

        # Give servers additional time to fully initialize and coordinate
        time.sleep(2)

        return self.servers

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop all server instances.

        Shutdown is best effort: a failure to stop one server is printed and
        the remaining servers are still stopped.
        """
        while self.servers:
            try:
                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
            except Exception as e:
                print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
    """Return the CLI flags shared by every server started in this module."""
    shared_flags = [
        # bfloat16 keeps CI runs fast and light on memory
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]
    return shared_flags
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
    """Yield the running servers, once per parametrized API server count."""
    manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE, request.param,
                                      default_server_args)
    # Leaving the ``with`` block on teardown stops the servers again.
    with manager as running_servers:
        yield running_servers
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
    """Yield one async client per server, in the order of ``servers``.

    Every client is entered through a single ``AsyncExitStack``, so all of
    their contexts are exited when the fixture is torn down.
    """
    async with AsyncExitStack() as exit_stack:
        opened_clients = []
        for remote_server, _ in servers:
            async_client = remote_server.get_async_client()
            opened_clients.append(
                await exit_stack.enter_async_context(async_client))
        yield opened_clients
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_single_completion(clients: list[
    openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
                                             model_name: str) -> None:
    """Check that every server answers plain completion requests.

    Each server first gets one sequential request, then two bursts of
    concurrent requests are sent to all servers at once.
    """

    async def make_request(client: openai.AsyncOpenAI):
        """Send one completion request and sanity-check the response."""
        completion = await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=10,
            temperature=1.0)

        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        # The exact number of tokens can vary slightly with temperature=1.0,
        # so we check for a reasonable minimum length.
        assert len(choice.text) >= 1
        # Finish reason might not always be 'length' if the model finishes early
        # or due to other reasons, especially with high temperature.
        # So, we'll accept 'length' or 'stop'.
        assert choice.finish_reason in ("length", "stop")

        # Token counts can also vary, so we check they are positive.
        # Fail with a clear assertion rather than an AttributeError if the
        # server omitted the usage section.
        assert completion.usage is not None
        assert completion.usage.completion_tokens > 0
        assert completion.usage.prompt_tokens > 0
        assert completion.usage.total_tokens > 0
        return completion

    async def run_burst(requests_per_server: int) -> None:
        """Send ``requests_per_server`` concurrent requests to every server."""
        results = await asyncio.gather(*(make_request(client)
                                         for client in clients
                                         for _ in range(requests_per_server)))
        assert len(results) == requests_per_server * len(clients)
        assert all(completion is not None for completion in results)

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_request(client)
        assert result is not None
        print(f"Server {i} handled single completion request successfully")

    await asyncio.sleep(0.5)

    # Each burst sends this many requests to every server concurrently.
    num_requests_per_server = 25
    await run_burst(num_requests_per_server)

    await asyncio.sleep(0.5)

    # Second burst of requests
    await run_burst(num_requests_per_server)

    _, server_args = servers[0]
    flag = "--api-server-count"
    # Report the value following the flag, or 1 when the flag is absent.
    api_server_count = (server_args[server_args.index(flag) + 1]
                        if flag in server_args else 1)
    print(
        f"Successfully completed external LB test with {len(clients)} servers "
        f"(API server count: {api_server_count})")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_external_lb_completion_streaming(clients: list[
    openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
                                                model_name: str) -> None:
    """Check that streamed completions match non-streamed ones on every server.

    Each server first gets one sequential request pair, then two bursts of
    concurrent request pairs are sent to all servers at once.
    """
    prompt = "What is an LLM?"

    async def make_streaming_request(client: openai.AsyncOpenAI):
        """Compare one greedy streamed completion with its non-streamed twin."""
        # Perform a non-streaming request to get the expected full output
        single_completion = await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
        )
        single_output = single_completion.choices[0].text

        # Perform the streaming request
        stream = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=5,
                                                 temperature=0.0,
                                                 stream=True)
        chunks: list[str] = []
        finish_reason_count = 0
        last_chunk = None
        async for chunk in stream:
            chunks.append(chunk.choices[0].text)
            if chunk.choices[0].finish_reason is not None:
                finish_reason_count += 1
            last_chunk = chunk  # Keep track of the last chunk

        # finish reason should only return in the last block for OpenAI API
        assert finish_reason_count == 1, (
            "Finish reason should appear exactly once.")
        assert last_chunk is not None, (
            "Stream should have yielded at least one chunk.")
        assert last_chunk.choices[
            0].finish_reason == "length", "Finish reason should be 'length'."
        # Check that the combined text matches the non-streamed version.
        assert "".join(
            chunks
        ) == single_output, "Streamed output should match non-streamed output."
        return True  # Indicate success for this request

    async def run_burst(requests_per_server: int) -> None:
        """Send ``requests_per_server`` concurrent requests to every server."""
        results = await asyncio.gather(*(make_streaming_request(client)
                                         for client in clients
                                         for _ in range(requests_per_server)))
        assert len(results) == requests_per_server * len(clients)
        assert all(
            results), "Not all streaming requests completed successfully."

    # Test single request to each server
    for i, client in enumerate(clients):
        result = await make_streaming_request(client)
        assert result is not None
        print(f"Server {i} handled single streaming request successfully")

    await asyncio.sleep(0.5)

    # Each burst sends this many streaming requests to every server
    # concurrently.
    num_requests_per_server = 25
    await run_burst(num_requests_per_server)

    await asyncio.sleep(0.5)

    # Second burst of streaming requests
    await run_burst(num_requests_per_server)

    _, server_args = servers[0]
    flag = "--api-server-count"
    # Report the value following the flag, or 1 when the flag is absent.
    api_server_count = (server_args[server_args.index(flag) + 1]
                        if flag in server_args else 1)
    print(f"Successfully completed external LB streaming test with "
          f"{len(clients)} servers (API server count: {api_server_count})")
...@@ -12,8 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine ...@@ -12,8 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription "openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder "facebook/bart-large-cnn", # encoder decoder
"mistralai/Mamba-Codestral-7B-v0.1", # mamba "state-spaces/mamba-130m-hf", # mamba1
"hmellor/tiny-random-BambaForCausalLM", # hybrid
"BAAI/bge-m3", # embedding "BAAI/bge-m3", # embedding
] ]
...@@ -74,12 +73,6 @@ def test_unsupported_configs(monkeypatch): ...@@ -74,12 +73,6 @@ def test_unsupported_configs(monkeypatch):
disable_async_output_proc=True, disable_async_output_proc=True,
).create_engine_config() ).create_engine_config()
with pytest.raises(NotImplementedError):
AsyncEngineArgs(
model=MODEL,
scheduling_policy="priority",
).create_engine_config()
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
AsyncEngineArgs( AsyncEngineArgs(
model=MODEL, model=MODEL,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.request import RequestStatus
def test_request_status_fmt_str():
    """Check that each RequestStatus member formats to its own name."""
    member_names = (
        "WAITING",
        "WAITING_FOR_FSM",
        "WAITING_FOR_REMOTE_KVS",
        "RUNNING",
        "PREEMPTED",
        "FINISHED_STOPPED",
        "FINISHED_LENGTH_CAPPED",
        "FINISHED_ABORTED",
        "FINISHED_IGNORED",
    )
    for member_name in member_names:
        status = getattr(RequestStatus, member_name)
        assert f"{status}" == member_name
...@@ -67,6 +67,43 @@ def test_basic( ...@@ -67,6 +67,43 @@ def test_basic(
assert "1024" in output or "0, 1" in output assert "1024" in output or "0, 1" in output
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    max_tokens: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode three prompts with Phi-3 and check the continuations."""
    # Prompt i is expected to be continued with expected_snippets[i].
    input_prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    expected_snippets = [
        " or, by violating privacy",
        " what is essential is love.",
        " but in rising every time we fall.",
    ]
    # test head dim = 96
    model_id = "microsoft/Phi-3-mini-128k-instruct"

    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")

        runner = vllm_runner(model_id,
                             max_num_batched_tokens=256,
                             max_num_seqs=max_num_seqs)
        with runner as vllm_model:
            greedy_outputs = vllm_model.generate_greedy(
                input_prompts, max_tokens)

        # Each output is a tuple: element 0 is the token ids, element 1 is
        # the output text (including the prompt).
        for output, snippet in zip(greedy_outputs, expected_snippets):
            assert snippet in output[1]
TP_SIZE_8 = 8 TP_SIZE_8 = 8
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
import torch_xla
import vllm.v1.attention.backends.pallas # noqa: F401
from vllm.platforms import current_platform
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.parametrize("page_size", [32, 33])
@pytest.mark.parametrize("combined_kv_head_num", [2, 16])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("num_slices_per_block", [4, 8])
def test_kv_cache_update_kernel(page_size: int, combined_kv_head_num: int,
                                head_dim: int, num_slices_per_block: int):
    """Compare the XLA KV cache update op with a CPU slice-copy reference."""
    total_pages = 1000
    token_count = 128
    xla_device = torch_xla.device()

    # Zero-initialized cache and random new KV data, mirrored on CPU and XLA.
    cache_ref = torch.zeros(
        (total_pages * page_size, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    cache_dev = cache_ref.to(xla_device)
    fresh_kv_ref = torch.randn(
        (token_count, combined_kv_head_num, head_dim),
        dtype=torch.bfloat16,
        device="cpu")
    fresh_kv_dev = fresh_kv_ref.to(xla_device)

    # Length of each update slice and where it starts inside the cache.
    lengths = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32)
    slice_count = len(lengths)
    dst_starts = np.array([
        page_size * 2 - 7, page_size * 2, page_size * 3, page_size * 4 + 6,
        page_size * 5 + 7, page_size * 6 + 8, page_size * 15 + 3
    ],
                          dtype=np.int32)
    # Source offsets: 0 followed by the running sum of the preceding lengths.
    src_starts = np.concatenate(
        [np.array([0], dtype=np.int32),
         np.cumsum(lengths[:-1])])

    # One row per slice: (cache start, new-KV start, length). Rows are
    # zero-padded up to a multiple of num_slices_per_block, then transposed.
    mapping = np.stack([dst_starts, src_starts, lengths], axis=1)
    row_count = mapping.shape[0]
    padded_rows = ((row_count + num_slices_per_block - 1) //
                   num_slices_per_block * num_slices_per_block)
    mapping = np.pad(mapping, [[0, padded_rows - row_count], [0, 0]],
                     constant_values=0)
    mapping = np.transpose(mapping)
    mapping_ref = torch.tensor(mapping, device="cpu", dtype=torch.int32)
    mapping_dev = mapping_ref.to(xla_device)
    slice_count_dev = torch.tensor([slice_count],
                                   device=xla_device,
                                   dtype=torch.int32)
    torch_xla.sync()

    # Run the op under test and write its result back into the XLA cache.
    torch.ops.xla.dynamo_set_buffer_donor_(cache_dev, True)
    updated_cache = torch.ops.xla.kv_cache_update_op(
        fresh_kv_dev, mapping_dev, cache_dev, slice_count_dev, page_size,
        num_slices_per_block)
    cache_dev.copy_(updated_cache)
    torch_xla.sync()

    # Reference result: copy every slice on the CPU with plain indexing.
    for src, dst, length in zip(src_starts, dst_starts, lengths):
        cache_ref[dst:dst + length, :, :] = fresh_kv_ref[src:src +
                                                         length, :, :]

    assert torch.allclose(cache_dev.cpu(), cache_ref, atol=1e-4, rtol=1e-4)
...@@ -47,7 +47,7 @@ def test_ragged_paged_attention(): ...@@ -47,7 +47,7 @@ def test_ragged_paged_attention():
key = torch.zeros(num_tokens, num_kv_heads * head_size) key = torch.zeros(num_tokens, num_kv_heads * head_size)
value = torch.zeros(num_tokens, num_kv_heads * head_size) value = torch.zeros(num_tokens, num_kv_heads * head_size)
kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size) kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
slot_mapping = torch.zeros(num_tokens, dtype=torch.int64) slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64)
max_num_reqs = 8 max_num_reqs = 8
max_num_blocks_per_req = 8 max_num_blocks_per_req = 8
block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
...@@ -65,6 +65,7 @@ def test_ragged_paged_attention(): ...@@ -65,6 +65,7 @@ def test_ragged_paged_attention():
context_lens=context_lens, context_lens=context_lens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
num_seqs=num_seqs, num_seqs=num_seqs,
num_slices_per_kv_cache_update_block=8,
) )
with patch("torch.ops.xla.ragged_paged_attention" with patch("torch.ops.xla.ragged_paged_attention"
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
import tempfile import tempfile
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile import tempfile
import numpy as np import numpy as np
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment