Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
"""An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command:
(single image inference with Llava)
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
(multi-image inference with Phi-3.5-vision-instruct)
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
--trust-remote-code --limit-mm-per-prompt image=2
"""
import base64
......@@ -84,3 +90,36 @@ chat_completion_from_base64 = client.chat.completions.create(
result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")
# Multi-image input inference
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What are the animals in these images?"
},
{
"type": "image_url",
"image_url": {
"url": image_url_duck
},
},
{
"type": "image_url",
"image_url": {
"url": image_url_lion
},
},
],
}],
model=model,
max_tokens=64,
)
result = chat_completion_from_url.choices[0].message.content
print(f"Chat completion output:{result}")
{%- macro json_to_python_type(json_spec) %}
{%- set basic_type_map = {
"string": "str",
"number": "float",
"integer": "int",
"boolean": "bool"
} %}
{%- if basic_type_map[json_spec.type] is defined %}
{{- basic_type_map[json_spec.type] }}
{%- elif json_spec.type == "array" %}
{{- "list[" + json_to_python_type(json_spec|items) + "]" }}
{%- elif json_spec.type == "object" %}
{%- if json_spec.additionalProperties is defined %}
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
{%- else %}
{{- "dict" }}
{%- endif %}
{%- elif json_spec.type is iterable %}
{{- "Union[" }}
{%- for t in json_spec.type %}
{{- json_to_python_type({"type": t}) }}
{%- if not loop.last %}
{{- "," }}
{%- endif %}
{%- endfor %}
{{- "]" }}
{%- else %}
{{- "Any" }}
{%- endif %}
{%- endmacro %}
{{- bos_token }}
{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
{%- if tools is iterable and tools | length > 0 %}
{%- for tool in tools %}
{%- if tool.function is defined %}
{%- set tool = tool.function %}
{%- endif %}
{{- '{"type": "function", "function": ' }}
{{- '{"name": "' + tool.name + '", ' }}
{{- '"description": "' + tool.name + '(' }}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{{- param_name + ": " + json_to_python_type(param_fields) }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- ")" }}
{%- if tool.return is defined %}
{{- " -> " + json_to_python_type(tool.return) }}
{%- endif %}
{{- " - " + tool.description + "\n\n" }}
{%- for param_name, param_fields in tool.parameters.properties|items %}
{%- if loop.first %}
{{- " Args:\n" }}
{%- endif %}
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
{%- endfor %}
{%- if tool.return is defined and tool.return.description is defined %}
{{- "\n Returns:\n " + tool.return.description }}
{%- endif %}
{{- '"' }}
{{- ', "parameters": ' }}
{%- if tool.parameters.properties | length == 0 %}
{{- "{}" }}
{%- else %}
{{- tool.parameters|tojson }}
{%- endif %}
{{- "}" }}
{%- if not loop.last %}
{{- "\n" }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- " </tools>" }}
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
' }}
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
" }}
{{- "<tool_call>
" }}
{{- '{"name": <function-name>, "arguments": <args-dict>}
' }}
{{- '</tool_call><|im_end|>' }}
{%- for message in messages %}
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" and message.tool_calls is defined %}
{{- '<|im_start|>' + message.role }}
{%- for tool_call in message.tool_calls %}
{{- '\n<tool_call>\n' }}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '{' }}
{{- '"name": "' }}
{{- tool_call.name }}
{{- '"}' }}
{{- ', ' }}
{%- if tool_call.arguments is defined %}
{{- '"arguments": ' }}
{{- tool_call.arguments|tojson }}
{%- endif %}
{{- '\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.previtem and loop.previtem.role != "tool" %}
{{- '<|im_start|>tool\n' }}
{%- endif %}
{{- '<tool_response>\n' }}
{{- message.content }}
{%- if not loop.last %}
{{- '\n</tool_response>\n' }}
{%- else %}
{{- '\n</tool_response>' }}
{%- endif %}
{%- if not loop.last and loop.nextitem.role != "tool" %}
{{- '<|im_end|>' }}
{%- elif loop.last %}
{{- '<|im_end|>' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- " " + message["content"] + eos_token }}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}
{%- if messages[0]["role"] == "system" %}
{%- set system_message = messages[0]["content"] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- if tools is defined %}
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
{%- if system_message is defined %}
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
{%- else %}
{%- set system_message = parallel_tool_prompt %}
{%- endif %}
{%- endif %}
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
{%- endif %}
{%- endfor %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if message["role"] == "user" %}
{%- if tools is not none and (message == user_messages[-1]) %}
{{- "[AVAILABLE_TOOLS] [" }}
{%- for tool in tools %}
{%- set tool = tool.function %}
{{- '{"type": "function", "function": {' }}
{%- for key, val in tool.items() if key != "return" %}
{%- if val is string %}
{{- '"' + key + '": "' + val + '"' }}
{%- else %}
{{- '"' + key + '": ' + val|tojson }}
{%- endif %}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- "}}" }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "[/AVAILABLE_TOOLS]" }}
{%- endif %}
{%- if loop.last and system_message is defined %}
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
{%- else %}
{{- "[INST] " + message["content"] + "[/INST]" }}
{%- endif %}
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
{%- if message.tool_calls is defined %}
{%- set tool_calls = message.tool_calls %}
{%- else %}
{%- set tool_calls = message.content %}
{%- endif %}
{{- "[TOOL_CALLS] [" }}
{%- for tool_call in tool_calls %}
{%- set out = tool_call.function|tojson %}
{{- out[:-1] }}
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
{%- endif %}
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" + eos_token }}
{%- endif %}
{%- endfor %}
{%- elif message["role"] == "assistant" %}
{{- " " + message["content"] + eos_token }}
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
{%- if message.content is defined and message.content.content is defined %}
{%- set content = message.content.content %}
{%- else %}
{%- set content = message.content %}
{%- endif %}
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
{%- endif %}
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
{%- else %}
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
{%- endif %}
{%- endfor %}
......@@ -99,7 +99,6 @@ echo 'vLLM mypy:'
mypy --follow-imports skip # Note that this is less strict than CI
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/executor --follow-imports skip
......
......@@ -58,6 +58,7 @@ files = [
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/core",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
......
......@@ -20,9 +20,10 @@ lm-format-enforcer == 0.10.6
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
typing_extensions >= 4.10
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
partial-json-parser # used for parsing partial JSON outputs
pyzmq
msgspec
librosa # Required for audio processing
soundfile # Required for audio processing
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
pyyaml
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
......@@ -8,3 +8,4 @@ botocore
ray >= 2.10.0
peft
pytest-asyncio
tensorizer>=2.9.0
\ No newline at end of file
......@@ -11,12 +11,14 @@ pytest-shard
# testing utils
awscli
einops # required for MPT and qwen-vl
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio test
peft
requests
ray
sentence-transformers # required for embedding
soundfile # required for audio test
compressed-tensors==0.4.0 # required for compressed-tensors
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
......@@ -30,4 +32,4 @@ aiohttp
# quantization
bitsandbytes==0.42.0
buildkite-test-collector==0.1.8
\ No newline at end of file
buildkite-test-collector==0.1.8
......@@ -4,4 +4,4 @@
# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
ray
ray[default]
......@@ -426,7 +426,8 @@ def get_vllm_version() -> str:
# version = find_version(get_path("vllm", "version.py"))
if _no_device():
version += "+empty"
if envs.VLLM_TARGET_DEVICE == "empty":
version += "+empty"
elif _is_cuda():
cuda_version = str(get_nvcc_cuda_version())
if cuda_version != MAIN_CUDA_VERSION:
......@@ -566,6 +567,7 @@ setup(
ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer>=2.9.0"],
"audio": ["librosa", "soundfile"] # Required for audio processing
},
cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {},
package_data=package_data,
......
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from ..utils import VLLM_PATH, RemoteOpenAIServer
......@@ -31,9 +32,10 @@ def server():
yield remote_server
@pytest.fixture(scope="module")
def client(server):
return server.get_async_client()
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
......
......@@ -6,6 +6,7 @@ prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
from contextlib import nullcontext
import pytest
......@@ -15,18 +16,6 @@ MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
E5M2_KV_MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-chat-hf",
]
E4M3_KV_MODELS = [
"meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
]
KV_CACHE_QUANTIZATION_PATHS = {
"meta-llama/Llama-2-7b-chat-hf":
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
}
@pytest.mark.parametrize("model", MODELS)
......@@ -77,10 +66,10 @@ def test_models(
)
@pytest.mark.parametrize("kv_cache_dtype,model",
[("fp8_e5m2", m)
for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
for m in E4M3_KV_MODELS])
@pytest.mark.parametrize(
"kv_cache_dtype,model",
[("fp8_e4m3",
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")])
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
......@@ -88,6 +77,9 @@ def test_models(
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
# Due to low-precision numerical divergence, this test is too sensitive to
# the async postprocessor
@pytest.mark.parametrize("disable_async_output_proc", [True])
def test_models_with_fp8_kv_cache(
vllm_runner,
example_prompts,
......@@ -97,36 +89,25 @@ def test_models_with_fp8_kv_cache(
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
disable_async_output_proc: bool,
) -> None:
"""
Only checks log probs match between chunked-prefill and
non-chunked-prefill version of vLLM model runner.
This test is used when there is discrepancy in kernels
/ numerics (e.g. when using lower-precision types like FP8).
Check output logprobs match between no_chunked_prefill and chunked_prefill
with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py,
so here we only check chunked prefill.
"""
NUM_LOG_PROBS = 8
if model == "facebook/opt-125m":
pytest.skip(
"#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
)
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
extra_kwargs = {}
if model in KV_CACHE_QUANTIZATION_PATHS:
extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
model]
with vllm_runner(
model,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
**extra_kwargs,
disable_async_output_proc=disable_async_output_proc,
) as vllm_model:
no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)
......@@ -139,7 +120,7 @@ def test_models_with_fp8_kv_cache(
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
kv_cache_dtype=kv_cache_dtype,
**extra_kwargs,
disable_async_output_proc=disable_async_output_proc,
) as vllm_model:
chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS)
......@@ -150,3 +131,68 @@ def test_models_with_fp8_kv_cache(
name_0="no_chunked_prefill",
name_1="chunked_prefill",
)
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_with_prefix_caching(
vllm_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
use_v2_block_manager: bool,
tensor_parallel_size: int,
) -> None:
"""
Checks exact match decode with and without prefix caching
with chunked prefill enabled.
"""
model = "meta-llama/Llama-2-7b-chat-hf"
# The common prompt has 142 tokens with Llama-2 tokenizer.
common_prompt = "You are a helpful AI assistant " * 20
unique_prompts = [
"Question", # Warmup
"Question", # Fully cached
"Another question", # Partial cached
]
full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]
max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with vllm_runner(
model,
dtype="half",
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
tensor_parallel_size=tensor_parallel_size,
use_v2_block_manager=use_v2_block_manager,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as vllm_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += vllm_model.generate_greedy([prompt],
max_tokens)
# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
......@@ -212,7 +212,6 @@ def test_swap_infeasible(
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]
with vllm_runner(
model,
dtype=dtype,
......
from typing import Optional
import torch
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
class MyMod(torch.nn.Module):
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
if cache is not None:
return x + cache
return x * 2
class MyWrapper(TorchCompileWrapperWithCustomDispacther):
def __init__(self, model):
self.model = model
compiled_callable = torch.compile(self.forward, backend="eager")
super().__init__(compiled_callable)
def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# this is the function to be compiled
return self.model(x, cache)
def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None):
# let torch.compile compile twice
if len(self.compiled_codes) == 2:
dispatch_id = 0 if cache is None else 1
with self.dispatch_to_code(dispatch_id):
return self.forward(x, cache)
else:
return self.compiled_callable(x, cache)
def test_torch_compile_wrapper():
mod = MyMod()
wrappers = []
for i in range(3):
torch._dynamo.reset()
wrapper = MyWrapper(mod)
wrappers.append(wrapper)
x = torch.tensor([1])
wrapper(x, None) # profile run, compile
# create a cache tensor
cache = torch.tensor([2])
wrapper(x, cache) # warm up with cache, recompile
# for new input, dispatch to the compiled code directly
new_x = torch.tensor([3])
assert wrapper(new_x,
None).item() == 6 # dispatch to the first compiled code
assert wrapper(
new_x, cache).item() == 5 # dispatch to the second compiled code
for wrapper in wrappers:
# make sure they have independent compiled codes
assert len(wrapper.compiled_codes) == 2
......@@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]
def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
......@@ -161,7 +165,7 @@ def example_encoder_decoder_prompts(
decoder prompt) tuple.
Returns:
* Encoder prompt list
* Decoder prompt list (reverse of encoder prompt list)
'''
......@@ -205,8 +209,14 @@ class HfRunner:
def wrap_device(self, input: _T) -> _T:
if not is_cpu():
# Check if the input is already on the GPU
if hasattr(input, 'device') and input.device.type == "cuda":
return input # Already on GPU, no need to move
return input.to("cuda")
else:
# Check if the input is already on the CPU
if hasattr(input, 'device') and input.device.type == "cpu":
return input # Already on CPU, no need to move
return input.to("cpu")
def __init__(
......@@ -578,8 +588,7 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
images: Optional[PromptImageInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
......@@ -623,10 +632,8 @@ class VllmRunner:
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None
......@@ -676,10 +683,8 @@ class VllmRunner:
prompts: List[str],
max_tokens: int,
num_logprobs: int,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
......
......@@ -708,6 +708,37 @@ class TestPrefixCachingBlockAllocator:
token_ids=token_ids)
assert allocator.get_prefix_cache_hit_rate() > 0.99
# Test case for marking cache hit blocks as computed right after
# a batch of prefill sequences are scheduled.
@staticmethod
def test_touch_block():
block_size = 16
common_blocks = 4
allocator = PrefixCachingBlockAllocator(num_blocks=8,
block_size=block_size)
common_token_ids = list(range(block_size * common_blocks))
# Mimic the behavior of allocating the same block chain
# (i.e., common prefix) for a batch of 3 different prefill sequences.
for _ in range(3):
blocks = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=common_token_ids,
allocator=allocator,
)
block_ids = [block.block_id for block in blocks]
# The allocated blocks should be marked as touched
# but not computed.
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
assert len(computed_block_ids) == 0
allocator.mark_blocks_as_computed([])
computed_block_ids = allocator.get_computed_block_ids(
[], block_ids, skip_last_block_id=False)
assert len(computed_block_ids) == common_blocks
@staticmethod
def create_immutable_chain(
block_size: int,
......
......@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
# assert all blocks are free now
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
"""When prefix cache and chunked prefill are enabled, the block manager
should only mark a chunk of blocks as computed instead of all blocks.
"""
block_size = 4
num_cpu_blocks = 0
num_gpu_blocks = 16
block_manager = BlockSpaceManagerV1(block_size,
num_gpu_blocks,
num_cpu_blocks,
watermark=0,
enable_caching=True)
# Set prompt size to have num_gpu_blocks - 1 full blocks.
prompt_length = block_size * num_gpu_blocks - 1
# Allocate (reserve) all blocks.
_, seq_group = create_dummy_prompt("0",
prompt_length,
block_size=block_size)
block_manager.allocate(seq_group)
assert seq_group.seqs[0].n_blocks == num_gpu_blocks
# 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
token_chunk_size = int(block_size * 2.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 2
# Actual computed tokens.
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)
# 2nd chunk: Complete 3rd block and additional 4 blocks.
token_chunk_size = int(block_size * 4.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 7
......@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
......@@ -180,7 +180,7 @@ def test_maximal_decoding():
"""Verify decoding requests are prioritized."""
block_size = 4
max_seqs = 2
max_model_len = 2
max_model_len = 8
max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
......@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
assert len(get_sequence_groups(out)) == max_seqs
assert not running[0].is_prefill()
assert not running[1].is_prefill()
def test_perfix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []
# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 50
# Verify it is chunked. Note that although the budget is 64-50=14,
# we only allocate full blocks for prefix caching, so only 4*(14//4)=12
# tokens are allocated.
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment