Unverified Commit 80e9afb5 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[V1][Core] Support for Structured Outputs (#12388)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1e3598ed
......@@ -204,6 +204,7 @@ steps:
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/structured_output
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting
......
......@@ -197,7 +197,7 @@ _build/
hip_compat.h
# Benchmark dataset
benchmarks/*.json
benchmarks/**/*.json
# Linting
actionlint
......
# SPDX-License-Identifier: Apache-2.0
"""Benchmark guided decoding throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import datasets
import pandas as pd
import uvloop
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
prompt: The input text prompt for the model.
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
"""
prompt: str
prompt_len: int
expected_output_len: int
schema: dict
structure_type: str = 'json'
completion: str = None
def run_vllm(requests: list[SampleRequest],
engine_args: EngineArgs,
n: int,
guided_decoding_rate: float = 1.0,
warmup: bool = False) -> float:
from vllm import LLM, SamplingParams
llm = LLM(**vars(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (
request.prompt_len + request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
# Add the requests to the engine.
prompts: list[str] = []
sampling_params: list[SamplingParams] = []
# create a list containing random selected true or false
guided_decoding_req_idx = random.sample(
range(len(requests)), int(len(requests) * guided_decoding_rate))
if warmup:
print(">>>>> Running warmup prompt, for the first 5")
# We setup the first 5 requests to warmup FSM
# if using xgrammar dataset, we will skip warmup
warmup_requests = requests[:5]
for i, request in enumerate(warmup_requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(json=request.schema)
if guided_decoding_rate > 0 else None,
))
llm.generate(prompts, sampling_params, use_tqdm=False)
print(">>>>> Benchmark started...")
prompts = []
sampling_params = []
for i, request in enumerate(requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(
**{request.structure_type: request.schema})
if i in guided_decoding_req_idx else None,
))
start = time.perf_counter()
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
ret = []
for output, request in zip(outputs, requests):
generated_text = output.outputs[0].text
ret.append({
"generated": generated_text,
"expected": request.completion
})
end = time.perf_counter()
return end - start, ret
async def run_vllm_async(
requests: list[SampleRequest],
engine_args: AsyncEngineArgs,
n: int,
guided_decoding_rate: float = 1.0,
warmup: bool = False,
disable_frontend_multiprocessing: bool = False) -> float:
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
assert all(
llm.model_config.max_model_len >= (request.prompt_len +
request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
# Add the requests to the engine.
prompts: list[str] = []
sampling_params: list[SamplingParams] = []
guided_decoding_req_idx = random.sample(
range(len(requests)), int(len(requests) * guided_decoding_rate))
if warmup:
print(">>>>>> Running warmup prompt, for the first 5")
# We setup the first 5 requests to warmup FSM
# if using xgrammar dataset, we will skip warmup
warmup_requests = requests[:5]
for i, request in enumerate(warmup_requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(
json=request.schema)
if guided_decoding_rate > 0 else None,
))
generators = []
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
print(">>>>> Benchmark started...")
prompts = []
sampling_params = []
for i, request in enumerate(requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(json=request.schema)
if i in guided_decoding_req_idx else None,
))
generators = []
start_time = []
latencies = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
start_time.append(time.perf_counter())
latencies.append([])
all_gens = merge_async_iterators(*generators)
generated_texts = [''] * len(requests)
async for i, res in all_gens:
generated_texts[i] = res.outputs[0].text
lat = time.perf_counter() - start_time[i]
latencies[i].append(lat)
ret = [{
'generated': gt,
'expected': req.completion
} for gt, req in zip(generated_texts, requests)]
end = time.perf_counter()
first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
for lat in latencies])
return end - start, ret, (first_latency, next_latency)
def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> list[SampleRequest]:
if args.dataset == 'json':
if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__))
args.json_schema_path = os.path.join(dir_path,
"structured_schemas",
"structured_schema_1.json")
with open(args.json_schema_path) as f:
schema = json.load(f)
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "grammar":
schema = """
?start: select_statement
?select_statement: "SELECT " column_list " FROM " table_name
?column_list: column_name ("," column_name)*
?table_name: identifier
?column_name: identifier
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
"""
prompt = "Generate an SQL query to show the 'username' \
and 'email' from the 'users' table."
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "regex":
regex = r"\w+@\w+\.com\n"
args.regex = regex
prompt = "Generate an email address for Alan Turing, \
who works in Enigma. End in .com and new line. \
Example result: alan.turing@enigma.com\n"
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=regex,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "choice":
choice = ["Positive", "Negative"]
args.choice = choice
prompt = "Classify this sentiment: vLLM is wonderful!"
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=choice,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "xgrammar_bench":
args.warmup = False
requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train")
print(f"dataset has {len(dataset)} entries")
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
while idx >= len_dataset:
idx -= len_dataset
schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
tokenize=False)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]
requests.append(
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
completion=completion))
return requests
def evaluate(ret, args):
def _eval_correctness_json(expected, actual):
# extract json string from string using regex
import re
actual = actual.replace('\n', '').replace(' ', '').strip()
try:
actual = re.search(r'\{.*\}', actual).group()
actual = json.loads(actual)
except Exception:
return False
return True
def _eval_correctness_choice(expected, actual):
return actual in args.choice
def _eval_correctness_regex(expected, actual):
import re
return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual):
if args.structure_type == 'json':
return _eval_correctness_json(expected, actual)
elif args.structure_type == 'regex':
return _eval_correctness_regex(expected, actual)
elif args.structure_type == 'choice':
return _eval_correctness_choice(expected, actual)
else:
return None
scores = []
for res in ret:
score = _eval_correctness(res['expected'], res['generated'])
res['correctness'] = score
scores.append(score)
not_none_scores = [score for score in scores if score is not None]
return (sum(not_none_scores) / len(not_none_scores) *
100) if len(not_none_scores) > 0 else None
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# async engine is working for 'regex', 'choice' and 'grammar'
if args.dataset == 'grammar':
args.structure_type = 'grammar'
args.async_engine = False
elif args.dataset == 'regex':
args.structure_type = 'regex'
args.async_engine = False
elif args.dataset == 'choice':
args.structure_type = 'choice'
args.async_engine = False
else:
args.structure_type = 'json'
if args.no_guided_decoding:
args.guided_decoding_ratio = 0
if args.save_results:
result_file_name = f'{args.guided_decoding_ratio}guided'
result_file_name += f"_{args.model.split('/')[-1]}"
result_file_name += f"_{args.dataset}"
result_file_name += f"_{args.num_prompts}"
result_file_name += f"_out{args.output_len}"
result_file_name += f"_async{args.async_engine}"
result_file_name += f"_warmup{args.warmup}"
result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
result_file_name += ".txt"
else:
result_file_name = None
# Synthesize a prompt with the given input length.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(tokenizer, args)
if args.async_engine:
engine_args = AsyncEngineArgs.from_cli_args(args)
elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
run_vllm_async(requests, engine_args, args.n,
args.guided_decoding_ratio, args.warmup,
args.disable_frontend_multiprocessing))
else:
engine_args = EngineArgs.from_cli_args(args)
elapsed_time, ret = run_vllm(requests, engine_args, args.n,
args.guided_decoding_ratio, args.warmup)
first_latency, next_latency = None, None
score = evaluate(ret, args)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
if first_latency is not None:
latency_breakdown = "\nFirst token latency(msecs):\n"
latency_breakdown += f"{first_latency.describe()}"
latency_breakdown += "\nNext token latency(msecs):\n"
latency_breakdown += f"{next_latency.describe()}"
print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
f"Correct rate is {score} %",
f"{latency_breakdown if first_latency is not None else ''}")
# Output JSON results if specified
if args.output_json or result_file_name:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"total_output_tokens": total_output_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
"output_tokens_per_second":
f"{total_output_tokens / elapsed_time:.2f}",
"correct_rate(%)": score
}
results = {"outputs": ret, **results}
if first_latency is not None:
results["first_token_latency(msecs)"] = first_latency.describe(
).to_dict()
results["next_token_latency(msecs)"] = next_latency.describe(
).to_dict()
if args.output_json:
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
elif result_file_name:
with open(result_file_name, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--output-len",
type=int,
default=512,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument(
"--dataset",
default='json',
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
parser.add_argument("--json_schema_path",
type=str,
default=None,
help="Path to json schema.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=10,
help="Number of prompts to process.")
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--no-guided-decoding",
action='store_true',
default=False,
help="Whether to disable JSON decoding or not.")
parser.add_argument("--guided-decoding-ratio",
type=float,
default=1.0,
help="Ratio of Guided Decoding requests")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument("--warmup",
action="store_true",
default=False,
help="Run warmup prompts before benchmark.")
parser.add_argument("--save-results",
action="store_true",
default=False,
help="save output results.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
main(args)
# SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput with guided decoding.
r"""Benchmark online serving throughput with structured outputs.
On the server side, run one of the following commands:
(vLLM OpenAI API server)
......@@ -9,12 +9,12 @@ On the server side, run one of the following commands:
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving_guided.py \
python benchmarks/benchmark_serving_structured_output.py \
--backend <backend> \
--model <your_model> \
--dataset json \
--guided-decoding-ratio 1.0 \
--guided-decoding-backend xgrammar \
--structured-output-ratio 1.0 \
--structured-output-backend xgrammar \
--request-rate 10 \
--num-prompts 1000
......@@ -52,6 +52,9 @@ try:
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.utils import (
has_xgrammar_unsupported_json_features)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
......@@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train")
print(f"dataset has {len(dataset)} entries")
full_dataset_len = len(dataset)
def _filter_func(item):
import json
schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset)
print(f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features")
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
......@@ -378,8 +391,8 @@ async def benchmark(
selected_percentiles: list[str],
ignore_eos: bool,
max_concurrency: Optional[int],
guided_decoding_ratio: float,
guided_decoding_backend: str,
structured_output_ratio: float,
structured_output_backend: str,
goodput_config_dict: Optional[dict[str, float]] = None,
):
if backend in ASYNC_REQUEST_FUNCS:
......@@ -391,16 +404,18 @@ async def benchmark(
extra_body = {}
# Add the schema to the extra_body
extra_body[request.structure_type] = request.schema
# Add the specific guided_decoding_backend
extra_body["guided_decoding_backend"] = guided_decoding_backend
# Add the specific structured_output_backend
extra_body["guided_decoding_backend"] = structured_output_backend
return extra_body
print("Starting initial single prompt test run...")
guided_decoding_req_idx = random.sample(
structured_output_req_idx = random.sample(
range(len(input_requests)),
int(len(input_requests) * guided_decoding_ratio))
int(len(input_requests) * structured_output_ratio))
test_request = input_requests[0]
test_req_extra_body = (prepare_extra_body(test_request)
if 0 in structured_output_req_idx else None)
test_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
......@@ -408,7 +423,7 @@ async def benchmark(
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=prepare_extra_body(test_request),
extra_body=test_req_extra_body,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
......@@ -427,7 +442,7 @@ async def benchmark(
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=prepare_extra_body(test_request),
extra_body=test_req_extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
......@@ -465,7 +480,7 @@ async def benchmark(
async for i, request in get_request(input_requests, request_rate,
burstiness):
extra_body = prepare_extra_body(
request) if i in guided_decoding_req_idx else None
request) if i in structured_output_req_idx else None
request_func_input = RequestFuncInput(
model=model_id,
prompt=request.prompt,
......@@ -708,10 +723,10 @@ def main(args: argparse.Namespace):
else:
args.structure_type = 'guided_json'
if args.no_guided_decoding:
args.guided_decoding_ratio = 0
if args.no_structured_output:
args.structured_output_ratio = 0
if args.save_results:
result_file_name = f'{args.guided_decoding_ratio}guided'
result_file_name = f'{args.structured_output_ratio}guided'
result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}"
......@@ -744,8 +759,8 @@ def main(args: argparse.Namespace):
],
ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency,
guided_decoding_ratio=args.guided_decoding_ratio,
guided_decoding_backend=args.guided_decoding_backend,
structured_output_ratio=args.structured_output_ratio,
structured_output_backend=args.structured_output_backend,
goodput_config_dict=goodput_config_dict,
))
......@@ -943,19 +958,19 @@ if __name__ == "__main__":
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
parser.add_argument("--no-guided-decoding",
parser.add_argument("--no-structured-output",
action='store_true',
default=False,
help="Whether to disable JSON decoding or not.")
parser.add_argument("--guided-decoding-ratio",
parser.add_argument("--structured-output-ratio",
type=float,
default=1.0,
help="Ratio of Guided Decoding requests")
parser.add_argument("--guided-decoding-backend",
help="Ratio of Structured Outputs requests")
parser.add_argument("--structured-output-backend",
type=str,
choices=["outlines", "lm-format-enforcer", "xgrammar"],
default="xgrammar",
help="Backend to use for guided decoding")
help="Backend to use for structured outputs")
args = parser.parse_args()
main(args)
#!/bin/bash
# Define the model to use
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
# Define the backend to use
BACKEND=${2:-"vllm"}
# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}
# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}
GUIDED_RATIO=${6:-0.5}
# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
# Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10)
# Common parameters
COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \
--dataset $DATASET \
--structured-output-backend $GUIDED_BACKEND \
--structured-output-ratio $GUIDED_RATIO \
--save-results \
--result-dir $OUTPUT_DIR"
echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------"
# Run benchmarks with different QPS values
for qps in "${QPS_VALUES[@]}"; do
echo "Running benchmark with QPS: $qps"
# Get git hash and branch for the filename
GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
# Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \
--result-filename "$FILENAME" \
--port ${PORT:-8000}
echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------"
done
echo "All benchmarks completed!"
echo "Results saved to: $OUTPUT_DIR"
{
"$schema":
"https://json-schema.org/draft/2020-12/schema",
"title":
"User Profile",
"type":
"object",
"properties": {
"userId": {
"type": "string",
"description": "Unique identifier for the user."
},
"personalInfo": {
"type": "object",
"properties": {
"firstName": {
"type": "string",
"description": "The user's first name."
},
"lastName": {
"type": "string",
"description": "The user's last name."
},
"age": {
"type": "integer",
"minimum": 0,
"description": "The user's age."
},
"phoneNumbers": {
"type":
"array",
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["home", "work", "mobile"],
"description": "Type of phone number."
},
"number": {
"type": "string",
"pattern": "^\\+?[1-9]\\d{1,14}$",
"description": "Phone number in E.164 format."
"name": { "type": "string" },
"race": { "type": "string" },
"class": { "type": "string" },
"level": { "type": "integer" },
"background": { "type": "string" },
"alignment": { "type": "string" },
"backstory": { "type": "string" }
},
"required": [
"name",
"race",
"class",
"level",
"background",
"alignment",
"backstory"
]
}
},
"required": ["type", "number"]
},
"description":
"List of phone numbers associated with the user."
}
},
"required": ["firstName", "lastName"]
},
"address": {
"type": "object",
"properties": {
"street": {
"type": "string",
"description": "Street address."
},
"city": {
"type": "string",
"description": "City name."
},
"state": {
"type": "string",
"description": "State or province."
},
"postalCode": {
"type": "string",
"pattern": "^\\d{5}(-\\d{4})?$",
"description": "Postal code."
},
"country": {
"type": "string",
"description": "Country name."
}
},
"required": ["street", "city", "state", "postalCode", "country"]
},
"preferences": {
"type": "object",
"properties": {
"newsletterSubscribed": {
"type":
"boolean",
"description":
"Indicates if the user is subscribed to the newsletter."
},
"favoriteCategories": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of user's favorite categories."
}
},
"required": ["newsletterSubscribed"]
},
"accountStatus": {
"type": "string",
"enum": ["active", "inactive", "suspended"],
"description": "Current status of the user's account."
},
"registrationDate": {
"type": "string",
"format": "date-time",
"description": "ISO 8601 formatted date-time of user registration."
}
},
"required":
["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
}
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256
......@@ -36,13 +37,21 @@ def create_scheduler(
swap_space=0,
cache_dtype="auto",
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
)
cache_config.num_gpu_blocks = 10000
return Scheduler(scheduler_config,
return Scheduler(
scheduler_config,
model_config,
cache_config,
speculative_config=None,
lora_config=None,
log_stats=True)
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(
......@@ -249,7 +258,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
......@@ -299,7 +310,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
......@@ -347,7 +360,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
......@@ -392,7 +407,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
......
......@@ -29,6 +29,7 @@ def sample_regex():
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
# Note: Ensure this only uses attributes compatible with xgrammar
@pytest.fixture
def sample_json_schema():
return {
......@@ -44,9 +45,7 @@ def sample_json_schema():
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
}
},
"work_history": {
"type": "array",
......@@ -71,8 +70,9 @@ def sample_json_schema():
}
# A schema unsupported by xgrammar
@pytest.fixture
def sample_complex_json_schema():
def unsupported_json_schema():
return {
"type": "object",
"properties": {
......@@ -150,7 +150,19 @@ def sample_guided_choice():
@pytest.fixture
def sample_sql_statements():
def sample_sql_ebnf():
return """
root ::= select_statement
select_statement ::= "SELECT" column "from" table "where" condition
column ::= "col_1" | "col_2"
table ::= "table_1" | "table_2"
condition ::= column "=" number
number ::= "1" | "2"
"""
@pytest.fixture
def sample_sql_lark():
return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
......
# SPDX-License-Identifier: Apache-2.0
import json
import jsonschema
import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_completion(monkeypatch, sample_json_schema,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=unsupported_json_schema,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {unsupported_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_ebnf,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_lark,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf_invalid(monkeypatch,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar="not a grammar",
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Failed to convert the grammar "
"from Lark to EBNF."):
llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Regex guided decoding is not supported."):
llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
# Once regex is supported --
#assert outputs is not None
#for output in outputs:
# assert output is not None
# assert isinstance(output, RequestOutput)
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(generated_text)
# assert generated_text is not None
# assert re.fullmatch(sample_regex, generated_text) is not None
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_choice_completion(monkeypatch, sample_guided_choice,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.v1.structured_output.utils import (
has_xgrammar_unsupported_json_features)
@pytest.fixture
def unsupported_string_schemas():
return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{
"type": "string",
"enum": ["active", "inactive", "pending"]
},
{
"type": "string",
"minLength": 1
},
{
"type": "string",
"maxLength": 100
},
{
"type": "string",
"format": "email"
},
]
@pytest.fixture
def unsupported_integer_schemas():
return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{
"type": "integer",
"multipleOf": 120
},
]
@pytest.fixture
def unsupported_number_schemas():
return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{
"type": "number",
"multipleOf": 120
},
]
@pytest.fixture
def unsupported_array_schemas():
return [
{
"type": "array",
"uniqueItems": True
},
{
"type": "array",
"contains": {
"type": "string"
}
},
{
"type": "array",
"minContains": 1
},
{
"type": "array",
"maxContains": 5
},
{
"type": "array",
"minItems": 1
},
{
"type": "array",
"maxItems": 10
},
]
@pytest.fixture
def unsupported_object_schemas():
return [
{
"type": "object",
"minProperties": 1
},
{
"type": "object",
"maxProperties": 5
},
{
"type": "object",
"propertyNames": {
"pattern": "^[a-z]+$"
}
},
{
"type": "object",
"patternProperties": {
"^S": {
"type": "string"
}
}
},
]
@pytest.fixture
def supported_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"status": {
"type": "string"
},
"scores": {
"type": "array",
"items": {
"type": "number"
}
},
"address": {
"type": "object",
"properties": {
"street": {
"type": "string"
},
"city": {
"type": "string"
}
}
}
}
}
@pytest.mark.parametrize("schema_type", [
"unsupported_string_schemas", "unsupported_integer_schemas",
"unsupported_number_schemas", "unsupported_array_schemas",
"unsupported_object_schemas"
])
def test_unsupported_json_features_by_type(schema_type, request):
schemas = request.getfixturevalue(schema_type)
for schema in schemas:
assert has_xgrammar_unsupported_json_features(
schema), f"Schema should be unsupported: {schema}"
def test_supported_json_features(supported_schema):
assert not has_xgrammar_unsupported_json_features(
supported_schema), "Schema should be supported"
......@@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
......@@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
......@@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
......@@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
......@@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
......@@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner._update_states(scheduler_output)
......
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import asyncio
import concurrent
......@@ -8,6 +10,7 @@ import datetime
import enum
import gc
import getpass
import importlib
import importlib.metadata
import importlib.util
import inspect
......@@ -23,6 +26,7 @@ import tempfile
import threading
import time
import traceback
import types
import uuid
import warnings
import weakref
......@@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream:
return _current_stream
def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
......@@ -1977,7 +1981,7 @@ class MemorySnapshot:
self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak,
cuda_memory=self.cuda_memory - other.cuda_memory,
......@@ -2306,3 +2310,54 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
type.__setattr__(cls, '__init__', wrapped_init)
return cls
class LazyLoader(types.ModuleType):
"""
LazyLoader module borrowed from Tensorflow
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
with a addition of "module caching".
Lazily import a module, mainly to avoid pulling in large dependencies.
Modules such as `xgrammar` might do additional side effects, so we
only want to use this when it is needed, delaying all eager effects
"""
def __init__(
self,
local_name: str,
parent_module_globals: dict[str, Any],
name: str,
):
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._module: types.ModuleType | None = None
super().__init__(str(name))
def _load(self) -> types.ModuleType:
# Import the target module and insert it into the parent's namespace
try:
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
# The additional add to sys.modules
# ensures library is actually loaded.
sys.modules[self._local_name] = module
except ModuleNotFoundError as err:
raise err from None
# Update this object's dict so that if someone keeps a
# reference to the LazyLoader, lookups are efficient
# (__getattr__ is only called on lookups that fail).
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, item: Any) -> Any:
if self._module is None:
self._module = self._load()
return getattr(self._module, item)
def __dir__(self) -> list[str]:
if self._module is None:
self._module = self._load()
return dir(self._module)
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import time
from collections import deque
from collections.abc import Iterable
......@@ -18,6 +20,7 @@ from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__)
......@@ -32,12 +35,14 @@ class Scheduler:
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
structured_output_manager: StructuredOutputManager,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.speculative_config = speculative_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
......@@ -97,7 +102,7 @@ class Scheduler:
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)
def schedule(self) -> "SchedulerOutput":
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
......@@ -114,6 +119,14 @@ class Scheduler:
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to the running request index.
# This will helps us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
......@@ -184,6 +197,12 @@ class Scheduler:
# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
......@@ -219,6 +238,10 @@ class Scheduler:
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
waiting_for_fsm: deque[Request] = deque()
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
......@@ -227,6 +250,16 @@ class Scheduler:
request = self.waiting[0]
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
waiting_structured_output_req = self.waiting.popleft()
waiting_for_fsm.appendleft(
waiting_structured_output_req)
continue
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
......@@ -281,6 +314,10 @@ class Scheduler:
break
self.waiting.popleft()
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
req_index += 1
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
self.request_scheduled(request, scheduled_timestamp)
......@@ -311,6 +348,10 @@ class Scheduler:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if waiting_for_fsm:
self.waiting.extendleft(waiting_for_fsm)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
......@@ -331,6 +372,11 @@ class Scheduler:
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
len(self.running),
)
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
......@@ -369,6 +415,8 @@ class Scheduler:
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
self.finished_req_ids = set()
......@@ -381,7 +429,7 @@ class Scheduler:
num_scheduled_spec_tokens: int,
new_block_ids: list[int],
resumed_from_preemption: bool,
) -> "CachedRequestData":
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
num_computed_tokens = request.num_computed_tokens
......@@ -474,8 +522,8 @@ class Scheduler:
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
......@@ -565,6 +613,15 @@ class Scheduler:
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request.request_id,
new_token_ids,
)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request.
......
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
......@@ -17,20 +22,20 @@ class NewRequestData:
req_id: str
prompt_token_ids: list[int]
prompt: Optional[str]
mm_inputs: list["MultiModalKwargs"]
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list["PlaceholderRange"]
sampling_params: "SamplingParams"
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: list[int]
num_computed_tokens: int
lora_request: Optional["LoRARequest"]
lora_request: Optional[LoRARequest]
@classmethod
def from_request(
cls,
request: "Request",
request: Request,
block_ids: list[int],
) -> "NewRequestData":
) -> NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
......@@ -60,11 +65,11 @@ class CachedRequestData:
@classmethod
def from_request(
cls,
request: "Request",
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: list[int],
) -> "CachedRequestData":
) -> CachedRequestData:
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
......@@ -111,3 +116,9 @@ class SchedulerOutput:
# list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]
# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]
......@@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)
......
......@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -61,6 +62,8 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
......@@ -69,6 +72,7 @@ class EngineCore:
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
)
# Setup MM Input Mapper.
......@@ -131,6 +135,9 @@ class EngineCore:
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
if req.use_structured_output:
# Start grammar compilation asynchronously
self.structured_output_manager.populate_cache(req)
self.scheduler.add_request(req)
......@@ -148,11 +155,24 @@ class EngineCore:
if not self.scheduler.has_unfinished_requests():
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
return engine_core_outputs
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
......
......@@ -66,9 +66,7 @@ class LLMEngine:
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)
......
......@@ -4,7 +4,7 @@ import time
from collections.abc import Mapping
from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
......@@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.structured_output.utils import validate_structured_output_request
class Processor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.tokenizer = tokenizer
self.generation_config_fields = model_config.try_get_generation_config(
)
self.input_preprocessor = InputPreprocessor(model_config,
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_processor = input_registry.create_input_processor(
model_config)
self.model_config)
# Multi-modal (huggingface) input mapper
self.mm_input_cache_client = MMInputCacheClient(model_config)
self.mm_input_cache_client = MMInputCacheClient(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
self.cache_config.enable_prefix_caching
def _validate_logprobs(
self,
......@@ -80,6 +82,8 @@ class Processor:
self,
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
......@@ -125,6 +129,21 @@ class Processor:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if (params.guided_decoding.backend
and params.guided_decoding.backend != 'xgrammar'):
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with "
"speculative decoding.")
validate_structured_output_request(params)
def process_inputs(
self,
request_id: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment