Commit 7a985548 authored by zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
...@@ -4,13 +4,22 @@ import torch ...@@ -4,13 +4,22 @@ import torch
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.core.specialized_manager import SlidingWindowManager from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec from vllm.v1.kv_cache_interface import SlidingWindowSpec
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False,
num_kv_cache_groups=1,
caching_hash_fn=lambda x: x)
def test_sliding_window_possible_cached_prefix(): def test_sliding_window_possible_cached_prefix():
block_size = 2
sliding_window_spec = SlidingWindowSpec( sliding_window_spec = SlidingWindowSpec(
block_size=2, block_size=block_size,
num_kv_heads=1, num_kv_heads=1,
head_size=1, head_size=1,
dtype=torch.float32, dtype=torch.float32,
...@@ -19,7 +28,7 @@ def test_sliding_window_possible_cached_prefix(): ...@@ -19,7 +28,7 @@ def test_sliding_window_possible_cached_prefix():
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length): def run_one_case(block_is_cached, expect_length):
block_hash_list = [ block_hash_list = [
...@@ -36,7 +45,9 @@ def test_sliding_window_possible_cached_prefix(): ...@@ -36,7 +45,9 @@ def test_sliding_window_possible_cached_prefix():
i: block_pool.blocks[i + 10] i: block_pool.blocks[i + 10]
} }
computed_blocks = manager.find_longest_cache_hit(block_hash_list) computed_blocks = manager.find_longest_cache_hit(
block_hash_list,
len(block_hash_list) * block_size)
assert len(computed_blocks) == expect_length assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block assert all(block == block_pool.null_block
...@@ -79,7 +90,7 @@ def test_sliding_window_remove_skipped_blocks(): ...@@ -79,7 +90,7 @@ def test_sliding_window_remove_skipped_blocks():
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
null_block_id = block_pool.null_block.block_id null_block_id = block_pool.null_block.block_id
...@@ -100,39 +111,35 @@ def test_sliding_window_remove_skipped_blocks(): ...@@ -100,39 +111,35 @@ def test_sliding_window_remove_skipped_blocks():
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
] ]
block_table = id_to_block_table(original_block_ids) block_table = id_to_block_table(original_block_ids)
removed = manager.remove_skipped_blocks(block_table, 0) manager.req_to_blocks["test"] = block_table
assert_block_id(removed, [])
manager.remove_skipped_blocks("test", 0)
assert_block_id(block_table, original_block_ids) assert_block_id(block_table, original_block_ids)
# 4 tokens are computed. Only token 0 is out of the sliding window. As # 4 tokens are computed. Only token 0 is out of the sliding window. As
# block 1000 also contains token 1 that is in the sliding window, block 1000 # block 1000 also contains token 1 that is in the sliding window, block 1000
# cannot be removed. # cannot be removed.
removed = manager.remove_skipped_blocks(block_table, 4) manager.remove_skipped_blocks("test", 4)
assert_block_id(removed, [])
assert_block_id(block_table, original_block_ids) assert_block_id(block_table, original_block_ids)
# 5 tokens are computed. Token 0 & 1 are out of the sliding window. # 5 tokens are computed. Token 0 & 1 are out of the sliding window.
# Block 1000 can be removed. # Block 1000 can be removed.
removed = manager.remove_skipped_blocks(block_table, 5) manager.remove_skipped_blocks("test", 5)
assert_block_id(removed, [original_block_ids[0]])
assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 6 tokens are computed. Token 0-2 are out of the sliding window. # 6 tokens are computed. Token 0-2 are out of the sliding window.
# Cannot remove new block as the block 1001 is still used by token 3. # Cannot remove new block as the block 1001 is still used by token 3.
removed = manager.remove_skipped_blocks(block_table, 6) manager.remove_skipped_blocks("test", 6)
assert_block_id(removed, [])
assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 7 tokens are computed. Token 0-3 are out of the sliding window. # 7 tokens are computed. Token 0-3 are out of the sliding window.
# Block 1001 can be removed and block 1000 is already removed. # Block 1001 can be removed and block 1000 is already removed.
removed = manager.remove_skipped_blocks(block_table, 7) manager.remove_skipped_blocks("test", 7)
assert_block_id(removed, [original_block_ids[1]])
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:]) assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# 11 tokens are computed. Token 0-7 are out of the sliding window. # 11 tokens are computed. Token 0-7 are out of the sliding window.
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer # Block 1002 & 1003 can be removed now. Block 1003 represents a longer
# sequence, and is expected to be evicted earlier than 1002, so the order # sequence, and is expected to be evicted earlier than 1002, so the order
# of removed blocks should be [1003, 1002]. # of removed blocks should be [1003, 1002].
removed = manager.remove_skipped_blocks(block_table, 11) manager.remove_skipped_blocks("test", 11)
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:]) assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
...@@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, ...@@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from ...distributed.conftest import publisher_config, random_port # noqa: F401
from tests.v1.engine.utils import FULL_STRINGS # isort: skip from tests.v1.engine.utils import FULL_STRINGS # isort: skip
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]] EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
......
...@@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest: ...@@ -40,6 +40,7 @@ def make_request() -> EngineCoreRequest:
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
cache_salt=None,
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import os
import signal
import time import time
import uuid import uuid
from threading import Thread from threading import Thread
from typing import Optional from typing import Optional
import psutil
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm import SamplingParams from vllm import SamplingParams
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -19,7 +21,9 @@ from vllm.v1.engine.core import EngineCore ...@@ -19,7 +21,9 @@ from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient) SyncMPClient)
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test from ...utils import create_new_process_for_each_test
if not current_platform.is_cuda(): if not current_platform.is_cuda():
...@@ -43,6 +47,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: ...@@ -43,6 +47,7 @@ def make_request(params: SamplingParams) -> EngineCoreRequest:
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
cache_salt=None,
) )
...@@ -198,81 +203,175 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): ...@@ -198,81 +203,175 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
log_stats=True, log_stats=True,
) )
MAX_TOKENS = 20 try:
params = SamplingParams(max_tokens=MAX_TOKENS) MAX_TOKENS = 20
"""Normal Request Cycle.""" params = SamplingParams(max_tokens=MAX_TOKENS)
"""Normal Request Cycle."""
requests = [make_request(params) for _ in range(10)] requests = [make_request(params) for _ in range(10)]
request_ids = [req.request_id for req in requests] request_ids = [req.request_id for req in requests]
# Add requests to the engine. # Add requests to the engine.
for request in requests: for request in requests:
await client.add_request_async(request) await client.add_request_async(request)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
outputs: dict[str, list] = {req_id: [] for req_id in request_ids} outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs) await loop_until_done_async(client, outputs)
for req_id in request_ids: for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, ( assert len(outputs[req_id]) == MAX_TOKENS, (
f"{outputs[req_id]=}, {MAX_TOKENS=}") f"{outputs[req_id]=}, {MAX_TOKENS=}")
"""Abort Request Cycle.""" """Abort Request Cycle."""
# Add requests to the engine.
for idx, request in enumerate(requests):
await client.add_request_async(request)
await asyncio.sleep(0.01)
if idx % 2 == 0:
await client.abort_requests_async([request.request_id])
outputs = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""
# Add requests to the engine. core_client: AsyncMPClient = client
for idx, request in enumerate(requests):
await client.add_request_async(request)
await asyncio.sleep(0.01)
if idx % 2 == 0:
await client.abort_requests_async([request.request_id])
outputs = {req_id: [] for req_id in request_ids} result = await core_client.call_utility_async("echo", "testarg")
await loop_until_done_async(client, outputs) assert result == "testarg"
for idx, req_id in enumerate(request_ids): with pytest.raises(Exception) as e_info:
if idx % 2 == 0: await core_client.call_utility_async("echo", None, "help!")
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}") assert str(e_info.value) == "Call to echo method failed: help!"
else: finally:
assert len(outputs[req_id]) == MAX_TOKENS, ( client.shutdown()
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""
core_client: AsyncMPClient = client
result = await core_client.call_utility_async("echo", "testarg") @pytest.mark.parametrize(
assert result == "testarg" "multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
indirect=["publisher_config"],
)
def test_kv_cache_events(
monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool,
publisher_config,
):
with pytest.raises(Exception) as e_info: with monkeypatch.context() as m:
await core_client.call_utility_async("echo", None, "help!") m.setenv("VLLM_USE_V1", "1")
block_size = 16
num_blocks = 2
assert str(e_info.value) == "Call to echo method failed: help!" engine_args = EngineArgs(model=MODEL_NAME,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
@pytest.mark.timeout(10) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
decode_type=KVEventBatch)
try:
custom_tokens = list(range(num_blocks * block_size))
request = EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=custom_tokens,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(
max_tokens=1), # Short completion for speed
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
)
client.add_request(request)
outputs: dict[str, list] = {request.request_id: []}
loop_until_done(client, outputs)
result = subscriber.receive_one(timeout=1000)
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert len(received.events) == 1, (
"We should have exactly one BlockStored event")
event = received.events[0]
assert isinstance(
event, BlockStored), ("We should have a BlockStored event")
assert len(event.block_hashes) == num_blocks, (
"We should have a BlockStored event with 2 block_hashes")
assert event.block_size == block_size, (
"Block size should be the same as the block size")
assert event.parent_block_hash is None, (
"Parent block hash should be None")
assert event.lora_id is None, "Lora id should be None"
assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens")
assert event.token_ids == custom_tokens, (
"Token ids should be the same as the custom tokens")
finally:
client.shutdown()
@pytest.mark.timeout(20)
def test_startup_failure(monkeypatch: pytest.MonkeyPatch): def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, pytest.raises(Exception) as e_info: with monkeypatch.context() as m, pytest.raises(Exception) as e_info:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
# Monkey-patch to extract core process pid while it's starting.
core_proc_pid = [None]
cepm_ctor = CoreEngineProcManager.__init__
def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
cepm_ctor(self, *args, **kwargs)
core_proc_pid[0] = self.processes[0].pid
m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)
t = time.time()
engine_args = EngineArgs(model=MODEL_NAME) engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config( vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT) usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
print(f"VllmConfig creation took {time.time() - t:.2f} seconds.")
# Start another thread to wait for engine core process to start # Start another thread to wait for engine core process to start
# and kill it - simulate fatal uncaught process exit. # and kill it - simulate fatal uncaught process exit.
this_proc = psutil.Process()
children_before = set(this_proc.children())
def kill_first_child(): def kill_first_child():
while True: while (child_pid := core_proc_pid[0]) is None:
time.sleep(0.5) time.sleep(0.5)
children = set(this_proc.children()) - children_before print(f"Killing child core process {child_pid}")
if children: assert isinstance(child_pid, int)
child = children.pop() os.kill(child_pid, signal.SIGKILL)
print("Killing child core process", child.pid)
child.kill()
break
Thread(target=kill_first_child, daemon=True).start() Thread(target=kill_first_child, daemon=True).start()
......
...@@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -57,6 +57,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
...@@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -403,6 +404,7 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
...@@ -503,7 +505,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -503,7 +505,7 @@ def test_stop_token(include_stop_str_in_output: bool,
reason should be "stop" (i.e. first control token causes stop reason should be "stop" (i.e. first control token causes stop
and is represented in output text) and is represented in output text)
* else, the detokenized string should be * else, the detokenized string should be
<token><token>...<token> and the finish reason should be "stop" <token><token>...<token> and the finish reason should be "stop"
(i.e. first control token causes stop but is not represented (i.e. first control token causes stop but is not represented
in output text.) in output text.)
...@@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -565,6 +567,7 @@ def test_stop_token(include_stop_str_in_output: bool,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=eos_token_id, eos_token_id=eos_token_id,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
...@@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -661,6 +664,7 @@ def test_stop_string(include_stop_str_in_output: bool,
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams( sampling_params=SamplingParams(
skip_special_tokens=False, skip_special_tokens=False,
spaces_between_special_tokens=False, spaces_between_special_tokens=False,
...@@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -774,6 +778,7 @@ def test_iteration_stats(dummy_test_vectors):
mm_placeholders=None, mm_placeholders=None,
eos_token_id=None, eos_token_id=None,
lora_request=None, lora_request=None,
cache_salt=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
......
...@@ -72,12 +72,16 @@ def sample_json_schema(): ...@@ -72,12 +72,16 @@ def sample_json_schema():
"type": "string" "type": "string"
} }
}, },
"required": ["company", "duration", "position"] "required": ["company", "duration", "position"],
} "additionalProperties": False
},
"minItems": 0,
"maxItems": 3
} }
}, },
"required": "required":
["name", "age", "skills", "grade", "email", "work_history"] ["name", "age", "skills", "grade", "email", "work_history"],
"additionalProperties": False
} }
...@@ -100,7 +104,8 @@ def unsupported_json_schema(): ...@@ -100,7 +104,8 @@ def unsupported_json_schema():
} }
} }
}, },
"required": ["score", "tags"] "required": ["score", "tags"],
"additionalProperties": False
} }
...@@ -139,7 +144,8 @@ def sample_definition_json_schema(): ...@@ -139,7 +144,8 @@ def sample_definition_json_schema():
}, },
'required': ['steps', 'final_answer'], 'required': ['steps', 'final_answer'],
'title': 'MathReasoning', 'title': 'MathReasoning',
'type': 'object' 'type': 'object',
"additionalProperties": False
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment