Unverified Commit 80e9afb5 authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[V1][Core] Support for Structured Outputs (#12388)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarRussell Bryant <rbryant@redhat.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarNick Hill <nhill@redhat.com>
parent 1e3598ed
...@@ -204,6 +204,7 @@ steps: ...@@ -204,6 +204,7 @@ steps:
- VLLM_USE_V1=1 pytest -v -s v1/engine - VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample - VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker - VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/structured_output
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting # TODO: accuracy does not match, whether setting
......
...@@ -197,7 +197,7 @@ _build/ ...@@ -197,7 +197,7 @@ _build/
hip_compat.h hip_compat.h
# Benchmark dataset # Benchmark dataset
benchmarks/*.json benchmarks/**/*.json
# Linting # Linting
actionlint actionlint
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput with guided decoding. r"""Benchmark online serving throughput with structured outputs.
On the server side, run one of the following commands: On the server side, run one of the following commands:
(vLLM OpenAI API server) (vLLM OpenAI API server)
...@@ -9,12 +9,12 @@ On the server side, run one of the following commands: ...@@ -9,12 +9,12 @@ On the server side, run one of the following commands:
./launch_tgi_server.sh <your_model> <max_batch_total_tokens> ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run: On the client side, run:
python benchmarks/benchmark_serving_guided.py \ python benchmarks/benchmark_serving_structured_output.py \
--backend <backend> \ --backend <backend> \
--model <your_model> \ --model <your_model> \
--dataset json \ --dataset json \
--guided-decoding-ratio 1.0 \ --structured-output-ratio 1.0 \
--guided-decoding-backend xgrammar \ --structured-output-backend xgrammar \
--request-rate 10 \ --request-rate 10 \
--num-prompts 1000 --num-prompts 1000
...@@ -52,6 +52,9 @@ try: ...@@ -52,6 +52,9 @@ try:
except ImportError: except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.utils import (
has_xgrammar_unsupported_json_features)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
...@@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, ...@@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
requests: list[SampleRequest] = [] requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval", dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train") split="train")
print(f"dataset has {len(dataset)} entries") full_dataset_len = len(dataset)
def _filter_func(item):
import json
schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset)
print(f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features")
len_dataset = len(dataset) len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts): for data_point_idx in range(args.num_prompts):
idx = data_point_idx idx = data_point_idx
...@@ -220,21 +233,21 @@ async def get_request( ...@@ -220,21 +233,21 @@ async def get_request(
burstiness: float = 1.0, burstiness: float = 1.0,
) -> AsyncGenerator[tuple[int, SampleRequest], None]: ) -> AsyncGenerator[tuple[int, SampleRequest], None]:
""" """
Asynchronously generates requests at a specified rate Asynchronously generates requests at a specified rate
with OPTIONAL burstiness. with OPTIONAL burstiness.
Args: Args:
input_requests: input_requests:
A list of input requests, each represented as a tuple. A list of input requests, each represented as a tuple.
request_rate: request_rate:
The rate at which requests are generated (requests/s). The rate at which requests are generated (requests/s).
burstiness (optional): burstiness (optional):
The burstiness factor of the request generation. The burstiness factor of the request generation.
Only takes effect when request_rate is not inf. Only takes effect when request_rate is not inf.
Default value is 1, which follows a Poisson process. Default value is 1, which follows a Poisson process.
Otherwise, the request intervals follow a gamma distribution. Otherwise, the request intervals follow a gamma distribution.
A lower burstiness value (0 < burstiness < 1) results A lower burstiness value (0 < burstiness < 1) results
in more bursty requests, while a higher burstiness value in more bursty requests, while a higher burstiness value
(burstiness > 1) results in a more uniform arrival of requests. (burstiness > 1) results in a more uniform arrival of requests.
""" """
input_requests = iter(input_requests) input_requests = iter(input_requests)
...@@ -378,8 +391,8 @@ async def benchmark( ...@@ -378,8 +391,8 @@ async def benchmark(
selected_percentiles: list[str], selected_percentiles: list[str],
ignore_eos: bool, ignore_eos: bool,
max_concurrency: Optional[int], max_concurrency: Optional[int],
guided_decoding_ratio: float, structured_output_ratio: float,
guided_decoding_backend: str, structured_output_backend: str,
goodput_config_dict: Optional[dict[str, float]] = None, goodput_config_dict: Optional[dict[str, float]] = None,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
...@@ -391,16 +404,18 @@ async def benchmark( ...@@ -391,16 +404,18 @@ async def benchmark(
extra_body = {} extra_body = {}
# Add the schema to the extra_body # Add the schema to the extra_body
extra_body[request.structure_type] = request.schema extra_body[request.structure_type] = request.schema
# Add the specific guided_decoding_backend # Add the specific structured_output_backend
extra_body["guided_decoding_backend"] = guided_decoding_backend extra_body["guided_decoding_backend"] = structured_output_backend
return extra_body return extra_body
print("Starting initial single prompt test run...") print("Starting initial single prompt test run...")
guided_decoding_req_idx = random.sample( structured_output_req_idx = random.sample(
range(len(input_requests)), range(len(input_requests)),
int(len(input_requests) * guided_decoding_ratio)) int(len(input_requests) * structured_output_ratio))
test_request = input_requests[0] test_request = input_requests[0]
test_req_extra_body = (prepare_extra_body(test_request)
if 0 in structured_output_req_idx else None)
test_input = RequestFuncInput( test_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=test_request.prompt, prompt=test_request.prompt,
...@@ -408,7 +423,7 @@ async def benchmark( ...@@ -408,7 +423,7 @@ async def benchmark(
prompt_len=test_request.prompt_len, prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len, output_len=test_request.expected_output_len,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=prepare_extra_body(test_request), extra_body=test_req_extra_body,
) )
test_output = await request_func(request_func_input=test_input) test_output = await request_func(request_func_input=test_input)
if not test_output.success: if not test_output.success:
...@@ -427,7 +442,7 @@ async def benchmark( ...@@ -427,7 +442,7 @@ async def benchmark(
prompt_len=test_request.prompt_len, prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len, output_len=test_request.expected_output_len,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
extra_body=prepare_extra_body(test_request), extra_body=test_req_extra_body,
) )
profile_output = await request_func(request_func_input=profile_input) profile_output = await request_func(request_func_input=profile_input)
if profile_output.success: if profile_output.success:
...@@ -465,7 +480,7 @@ async def benchmark( ...@@ -465,7 +480,7 @@ async def benchmark(
async for i, request in get_request(input_requests, request_rate, async for i, request in get_request(input_requests, request_rate,
burstiness): burstiness):
extra_body = prepare_extra_body( extra_body = prepare_extra_body(
request) if i in guided_decoding_req_idx else None request) if i in structured_output_req_idx else None
request_func_input = RequestFuncInput( request_func_input = RequestFuncInput(
model=model_id, model=model_id,
prompt=request.prompt, prompt=request.prompt,
...@@ -708,10 +723,10 @@ def main(args: argparse.Namespace): ...@@ -708,10 +723,10 @@ def main(args: argparse.Namespace):
else: else:
args.structure_type = 'guided_json' args.structure_type = 'guided_json'
if args.no_guided_decoding: if args.no_structured_output:
args.guided_decoding_ratio = 0 args.structured_output_ratio = 0
if args.save_results: if args.save_results:
result_file_name = f'{args.guided_decoding_ratio}guided' result_file_name = f'{args.structured_output_ratio}guided'
result_file_name += f"_{backend}" result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}" result_file_name += f"_{args.model.split('/')[-1]}"
...@@ -744,8 +759,8 @@ def main(args: argparse.Namespace): ...@@ -744,8 +759,8 @@ def main(args: argparse.Namespace):
], ],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
guided_decoding_ratio=args.guided_decoding_ratio, structured_output_ratio=args.structured_output_ratio,
guided_decoding_backend=args.guided_decoding_backend, structured_output_backend=args.structured_output_backend,
goodput_config_dict=goodput_config_dict, goodput_config_dict=goodput_config_dict,
)) ))
...@@ -943,19 +958,19 @@ if __name__ == "__main__": ...@@ -943,19 +958,19 @@ if __name__ == "__main__":
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve") "and the blog: https://hao-ai-lab.github.io/blogs/distserve")
parser.add_argument("--no-guided-decoding", parser.add_argument("--no-structured-output",
action='store_true', action='store_true',
default=False, default=False,
help="Whether to disable JSON decoding or not.") help="Whether to disable JSON decoding or not.")
parser.add_argument("--guided-decoding-ratio", parser.add_argument("--structured-output-ratio",
type=float, type=float,
default=1.0, default=1.0,
help="Ratio of Guided Decoding requests") help="Ratio of Structured Outputs requests")
parser.add_argument("--guided-decoding-backend", parser.add_argument("--structured-output-backend",
type=str, type=str,
choices=["outlines", "lm-format-enforcer", "xgrammar"], choices=["outlines", "lm-format-enforcer", "xgrammar"],
default="xgrammar", default="xgrammar",
help="Backend to use for guided decoding") help="Backend to use for structured outputs")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
#!/bin/bash
# Define the model to use
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
# Define the backend to use
BACKEND=${2:-"vllm"}
# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}
# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}
GUIDED_RATIO=${6:-0.5}
# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
# Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10)
# Common parameters
COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \
--dataset $DATASET \
--structured-output-backend $GUIDED_BACKEND \
--structured-output-ratio $GUIDED_RATIO \
--save-results \
--result-dir $OUTPUT_DIR"
echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------"
# Run benchmarks with different QPS values
for qps in "${QPS_VALUES[@]}"; do
echo "Running benchmark with QPS: $qps"
# Get git hash and branch for the filename
GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
# Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \
--result-filename "$FILENAME" \
--port ${PORT:-8000}
echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------"
done
echo "All benchmarks completed!"
echo "Results saved to: $OUTPUT_DIR"
{ {
"$schema": "type": "array",
"https://json-schema.org/draft/2020-12/schema", "items": {
"title": "type": "object",
"User Profile",
"type":
"object",
"properties": { "properties": {
"userId": { "name": { "type": "string" },
"type": "string", "race": { "type": "string" },
"description": "Unique identifier for the user." "class": { "type": "string" },
}, "level": { "type": "integer" },
"personalInfo": { "background": { "type": "string" },
"type": "object", "alignment": { "type": "string" },
"properties": { "backstory": { "type": "string" }
"firstName": {
"type": "string",
"description": "The user's first name."
},
"lastName": {
"type": "string",
"description": "The user's last name."
},
"age": {
"type": "integer",
"minimum": 0,
"description": "The user's age."
},
"phoneNumbers": {
"type":
"array",
"items": {
"type": "object",
"properties": {
"type": {
"type": "string",
"enum": ["home", "work", "mobile"],
"description": "Type of phone number."
},
"number": {
"type": "string",
"pattern": "^\\+?[1-9]\\d{1,14}$",
"description": "Phone number in E.164 format."
}
},
"required": ["type", "number"]
},
"description":
"List of phone numbers associated with the user."
}
},
"required": ["firstName", "lastName"]
},
"address": {
"type": "object",
"properties": {
"street": {
"type": "string",
"description": "Street address."
},
"city": {
"type": "string",
"description": "City name."
},
"state": {
"type": "string",
"description": "State or province."
},
"postalCode": {
"type": "string",
"pattern": "^\\d{5}(-\\d{4})?$",
"description": "Postal code."
},
"country": {
"type": "string",
"description": "Country name."
}
},
"required": ["street", "city", "state", "postalCode", "country"]
},
"preferences": {
"type": "object",
"properties": {
"newsletterSubscribed": {
"type":
"boolean",
"description":
"Indicates if the user is subscribed to the newsletter."
},
"favoriteCategories": {
"type": "array",
"items": {
"type": "string"
},
"description": "List of user's favorite categories."
}
},
"required": ["newsletterSubscribed"]
},
"accountStatus": {
"type": "string",
"enum": ["active", "inactive", "suspended"],
"description": "Current status of the user's account."
},
"registrationDate": {
"type": "string",
"format": "date-time",
"description": "ISO 8601 formatted date-time of user registration."
}
}, },
"required": "required": [
["userId", "personalInfo", "address", "accountStatus", "registrationDate"] "name",
} "race",
\ No newline at end of file "class",
"level",
"background",
"alignment",
"backstory"
]
}
}
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional from typing import Optional
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256 EOS_TOKEN_ID = 50256
...@@ -36,13 +37,21 @@ def create_scheduler( ...@@ -36,13 +37,21 @@ def create_scheduler(
swap_space=0, swap_space=0,
cache_dtype="auto", cache_dtype="auto",
) )
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
)
cache_config.num_gpu_blocks = 10000 cache_config.num_gpu_blocks = 10000
return Scheduler(scheduler_config, return Scheduler(
model_config, scheduler_config,
cache_config, model_config,
speculative_config=None, cache_config,
lora_config=None, speculative_config=None,
log_stats=True) lora_config=None,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests( def create_requests(
...@@ -249,7 +258,9 @@ def test_stop_via_update_from_output(): ...@@ -249,7 +258,9 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[]) free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -299,7 +310,9 @@ def test_stop_via_update_from_output(): ...@@ -299,7 +310,9 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[]) free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -347,7 +360,9 @@ def test_stop_via_update_from_output(): ...@@ -347,7 +360,9 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[]) free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -392,7 +407,9 @@ def test_stop_via_update_from_output(): ...@@ -392,7 +407,9 @@ def test_stop_via_update_from_output():
}, },
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[]) free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id], req_ids=[requests[0].request_id],
......
...@@ -29,6 +29,7 @@ def sample_regex(): ...@@ -29,6 +29,7 @@ def sample_regex():
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
# Note: Ensure this only uses attributes compatible with xgrammar
@pytest.fixture @pytest.fixture
def sample_json_schema(): def sample_json_schema():
return { return {
...@@ -44,9 +45,7 @@ def sample_json_schema(): ...@@ -44,9 +45,7 @@ def sample_json_schema():
"type": "array", "type": "array",
"items": { "items": {
"type": "string", "type": "string",
"maxLength": 10 }
},
"minItems": 3
}, },
"work_history": { "work_history": {
"type": "array", "type": "array",
...@@ -71,8 +70,9 @@ def sample_json_schema(): ...@@ -71,8 +70,9 @@ def sample_json_schema():
} }
# A schema unsupported by xgrammar
@pytest.fixture @pytest.fixture
def sample_complex_json_schema(): def unsupported_json_schema():
return { return {
"type": "object", "type": "object",
"properties": { "properties": {
...@@ -150,7 +150,19 @@ def sample_guided_choice(): ...@@ -150,7 +150,19 @@ def sample_guided_choice():
@pytest.fixture @pytest.fixture
def sample_sql_statements(): def sample_sql_ebnf():
return """
root ::= select_statement
select_statement ::= "SELECT" column "from" table "where" condition
column ::= "col_1" | "col_2"
table ::= "table_1" | "table_2"
condition ::= column "=" number
number ::= "1" | "2"
"""
@pytest.fixture
def sample_sql_lark():
return (""" return ("""
start: select_statement start: select_statement
select_statement: "SELECT" column "from" table "where" condition select_statement: "SELECT" column "from" table "where" condition
......
# SPDX-License-Identifier: Apache-2.0
import json
import jsonschema
import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_completion(monkeypatch, sample_json_schema,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=unsupported_json_schema,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {unsupported_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_ebnf,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_lark,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf_invalid(monkeypatch,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar="not a grammar",
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Failed to convert the grammar "
"from Lark to EBNF."):
llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Regex guided decoding is not supported."):
llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
# Once regex is supported --
#assert outputs is not None
#for output in outputs:
# assert output is not None
# assert isinstance(output, RequestOutput)
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(generated_text)
# assert generated_text is not None
# assert re.fullmatch(sample_regex, generated_text) is not None
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_choice_completion(monkeypatch, sample_guided_choice,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.v1.structured_output.utils import (
has_xgrammar_unsupported_json_features)
@pytest.fixture
def unsupported_string_schemas():
return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{
"type": "string",
"enum": ["active", "inactive", "pending"]
},
{
"type": "string",
"minLength": 1
},
{
"type": "string",
"maxLength": 100
},
{
"type": "string",
"format": "email"
},
]
@pytest.fixture
def unsupported_integer_schemas():
return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{
"type": "integer",
"multipleOf": 120
},
]
@pytest.fixture
def unsupported_number_schemas():
return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{
"type": "number",
"multipleOf": 120
},
]
@pytest.fixture
def unsupported_array_schemas():
return [
{
"type": "array",
"uniqueItems": True
},
{
"type": "array",
"contains": {
"type": "string"
}
},
{
"type": "array",
"minContains": 1
},
{
"type": "array",
"maxContains": 5
},
{
"type": "array",
"minItems": 1
},
{
"type": "array",
"maxItems": 10
},
]
@pytest.fixture
def unsupported_object_schemas():
return [
{
"type": "object",
"minProperties": 1
},
{
"type": "object",
"maxProperties": 5
},
{
"type": "object",
"propertyNames": {
"pattern": "^[a-z]+$"
}
},
{
"type": "object",
"patternProperties": {
"^S": {
"type": "string"
}
}
},
]
@pytest.fixture
def supported_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"status": {
"type": "string"
},
"scores": {
"type": "array",
"items": {
"type": "number"
}
},
"address": {
"type": "object",
"properties": {
"street": {
"type": "string"
},
"city": {
"type": "string"
}
}
}
}
}
@pytest.mark.parametrize("schema_type", [
"unsupported_string_schemas", "unsupported_integer_schemas",
"unsupported_number_schemas", "unsupported_array_schemas",
"unsupported_object_schemas"
])
def test_unsupported_json_features_by_type(schema_type, request):
schemas = request.getfixturevalue(schema_type)
for schema in schemas:
assert has_xgrammar_unsupported_json_features(
schema), f"Schema should be unsupported: {schema}"
def test_supported_json_features(supported_schema):
assert not has_xgrammar_unsupported_json_features(
supported_schema), "Schema should be supported"
...@@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: ...@@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
...@@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner): ...@@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids={req_id}, finished_req_ids={req_id},
free_encoder_input_ids=[], free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
...@@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner): ...@@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
model_runner._update_states(scheduler_output) model_runner._update_states(scheduler_output)
...@@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner): ...@@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
...@@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner): ...@@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
metadata_before = model_runner.input_batch.sampling_metadata metadata_before = model_runner.input_batch.sampling_metadata
...@@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner): ...@@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=0, num_common_prefix_blocks=0,
finished_req_ids=set(), finished_req_ids=set(),
free_encoder_input_ids=[], free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
) )
metadata_before = model_runner._update_states(scheduler_output) metadata_before = model_runner._update_states(scheduler_output)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import concurrent import concurrent
...@@ -8,6 +10,7 @@ import datetime ...@@ -8,6 +10,7 @@ import datetime
import enum import enum
import gc import gc
import getpass import getpass
import importlib
import importlib.metadata import importlib.metadata
import importlib.util import importlib.util
import inspect import inspect
...@@ -23,6 +26,7 @@ import tempfile ...@@ -23,6 +26,7 @@ import tempfile
import threading import threading
import time import time
import traceback import traceback
import types
import uuid import uuid
import warnings import warnings
import weakref import weakref
...@@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream: ...@@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream:
return _current_stream return _current_stream
def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None: def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
"""Set up function tracing for the current thread, """Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable if enabled via the VLLM_TRACE_FUNCTION environment variable
""" """
...@@ -1977,7 +1981,7 @@ class MemorySnapshot: ...@@ -1977,7 +1981,7 @@ class MemorySnapshot:
self.non_torch_memory = self.cuda_memory - self.torch_memory self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time() self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
return MemorySnapshot( return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak, torch_peak=self.torch_peak - other.torch_peak,
cuda_memory=self.cuda_memory - other.cuda_memory, cuda_memory=self.cuda_memory - other.cuda_memory,
...@@ -2306,3 +2310,54 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]: ...@@ -2306,3 +2310,54 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
type.__setattr__(cls, '__init__', wrapped_init) type.__setattr__(cls, '__init__', wrapped_init)
return cls return cls
class LazyLoader(types.ModuleType):
"""
LazyLoader module borrowed from Tensorflow
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
with a addition of "module caching".
Lazily import a module, mainly to avoid pulling in large dependencies.
Modules such as `xgrammar` might do additional side effects, so we
only want to use this when it is needed, delaying all eager effects
"""
def __init__(
self,
local_name: str,
parent_module_globals: dict[str, Any],
name: str,
):
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._module: types.ModuleType | None = None
super().__init__(str(name))
def _load(self) -> types.ModuleType:
# Import the target module and insert it into the parent's namespace
try:
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
# The additional add to sys.modules
# ensures library is actually loaded.
sys.modules[self._local_name] = module
except ModuleNotFoundError as err:
raise err from None
# Update this object's dict so that if someone keeps a
# reference to the LazyLoader, lookups are efficient
# (__getattr__ is only called on lookups that fail).
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, item: Any) -> Any:
if self._module is None:
self._module = self._load()
return getattr(self._module, item)
def __dir__(self) -> list[str]:
if self._module is None:
self._module = self._load()
return dir(self._module)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import time import time
from collections import deque from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
...@@ -18,6 +20,7 @@ from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, ...@@ -18,6 +20,7 @@ from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -32,12 +35,14 @@ class Scheduler: ...@@ -32,12 +35,14 @@ class Scheduler:
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
log_stats: bool, log_stats: bool,
structured_output_manager: StructuredOutputManager,
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config self.lora_config = lora_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.log_stats = log_stats self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
...@@ -97,7 +102,7 @@ class Scheduler: ...@@ -97,7 +102,7 @@ class Scheduler:
self.encoder_cache_manager = EncoderCacheManager( self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size) cache_size=encoder_cache_size)
def schedule(self) -> "SchedulerOutput": def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm: # NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler. # There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and # Each request just has the num_computed_tokens and
...@@ -114,6 +119,14 @@ class Scheduler: ...@@ -114,6 +119,14 @@ class Scheduler:
scheduled_running_reqs: list[Request] = [] scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = [] preempted_reqs: list[Request] = []
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to the running request index.
# This will helps us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {} req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {} num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens token_budget = self.max_num_scheduled_tokens
...@@ -184,6 +197,12 @@ class Scheduler: ...@@ -184,6 +197,12 @@ class Scheduler:
# Schedule the request. # Schedule the request.
scheduled_running_reqs.append(request) scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id) self.scheduled_req_ids.add(request.request_id)
if request.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks b.block_id for b in new_blocks
] ]
...@@ -219,6 +238,10 @@ class Scheduler: ...@@ -219,6 +238,10 @@ class Scheduler:
if req.lora_request and req.lora_request.lora_int_id > 0) if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras assert len(requested_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
waiting_for_fsm: deque[Request] = deque()
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs:
while self.waiting and token_budget > 0: while self.waiting and token_budget > 0:
...@@ -227,6 +250,16 @@ class Scheduler: ...@@ -227,6 +250,16 @@ class Scheduler:
request = self.waiting[0] request = self.waiting[0]
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
waiting_structured_output_req = self.waiting.popleft()
waiting_for_fsm.appendleft(
waiting_structured_output_req)
continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
...@@ -281,6 +314,10 @@ class Scheduler: ...@@ -281,6 +314,10 @@ class Scheduler:
break break
self.waiting.popleft() self.waiting.popleft()
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
req_index += 1
self.running.append(request) self.running.append(request)
self.scheduled_req_ids.add(request.request_id) self.scheduled_req_ids.add(request.request_id)
self.request_scheduled(request, scheduled_timestamp) self.request_scheduled(request, scheduled_timestamp)
...@@ -311,6 +348,10 @@ class Scheduler: ...@@ -311,6 +348,10 @@ class Scheduler:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if waiting_for_fsm:
self.waiting.extendleft(waiting_for_fsm)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
...@@ -331,6 +372,11 @@ class Scheduler: ...@@ -331,6 +372,11 @@ class Scheduler:
self.kv_cache_manager.get_num_common_prefix_blocks( self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running))) any_request, len(self.running)))
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
len(self.running),
)
# Construct the scheduler output. # Construct the scheduler output.
new_reqs_data = [ new_reqs_data = [
NewRequestData.from_request(req, NewRequestData.from_request(req,
...@@ -369,6 +415,8 @@ class Scheduler: ...@@ -369,6 +415,8 @@ class Scheduler:
# the previous and the current steps. # the previous and the current steps.
finished_req_ids=self.finished_req_ids, finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
) )
self.finished_req_ids = set() self.finished_req_ids = set()
...@@ -381,7 +429,7 @@ class Scheduler: ...@@ -381,7 +429,7 @@ class Scheduler:
num_scheduled_spec_tokens: int, num_scheduled_spec_tokens: int,
new_block_ids: list[int], new_block_ids: list[int],
resumed_from_preemption: bool, resumed_from_preemption: bool,
) -> "CachedRequestData": ) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step. # them at each scheduling step.
num_computed_tokens = request.num_computed_tokens num_computed_tokens = request.num_computed_tokens
...@@ -474,8 +522,8 @@ class Scheduler: ...@@ -474,8 +522,8 @@ class Scheduler:
def update_from_output( def update_from_output(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: SchedulerOutput,
model_runner_output: "ModelRunnerOutput", model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs: ) -> EngineCoreOutputs:
sampled_token_ids = model_runner_output.sampled_token_ids sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids spec_token_ids = model_runner_output.spec_token_ids
...@@ -565,6 +613,15 @@ class Scheduler: ...@@ -565,6 +613,15 @@ class Scheduler:
# the outer lists can be of length > 1. # the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1) new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request.request_id,
new_token_ids,
)
# Transmit partial if chunked prefill & prompt logprobs is enabled # Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None: if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request. # Add EngineCoreOutput for this Request.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange from vllm.multimodal.base import PlaceholderRange
...@@ -17,20 +22,20 @@ class NewRequestData: ...@@ -17,20 +22,20 @@ class NewRequestData:
req_id: str req_id: str
prompt_token_ids: list[int] prompt_token_ids: list[int]
prompt: Optional[str] prompt: Optional[str]
mm_inputs: list["MultiModalKwargs"] mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str] mm_hashes: list[str]
mm_positions: list["PlaceholderRange"] mm_positions: list[PlaceholderRange]
sampling_params: "SamplingParams" sampling_params: SamplingParams
block_ids: list[int] block_ids: list[int]
num_computed_tokens: int num_computed_tokens: int
lora_request: Optional["LoRARequest"] lora_request: Optional[LoRARequest]
@classmethod @classmethod
def from_request( def from_request(
cls, cls,
request: "Request", request: Request,
block_ids: list[int], block_ids: list[int],
) -> "NewRequestData": ) -> NewRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids, prompt_token_ids=request.prompt_token_ids,
...@@ -60,11 +65,11 @@ class CachedRequestData: ...@@ -60,11 +65,11 @@ class CachedRequestData:
@classmethod @classmethod
def from_request( def from_request(
cls, cls,
request: "Request", request: Request,
resumed_from_preemption: bool, resumed_from_preemption: bool,
new_token_ids: list[int], new_token_ids: list[int],
new_block_ids: list[int], new_block_ids: list[int],
) -> "CachedRequestData": ) -> CachedRequestData:
return cls( return cls(
req_id=request.request_id, req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption, resumed_from_preemption=resumed_from_preemption,
...@@ -111,3 +116,9 @@ class SchedulerOutput: ...@@ -111,3 +116,9 @@ class SchedulerOutput:
# list of (req_id, encoder_input_index) tuples. # list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache. # Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]] free_encoder_input_ids: list[tuple[str, int]]
# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]
...@@ -72,9 +72,7 @@ class AsyncLLM(EngineClient): ...@@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):
# Processor (converts Inputs --> EngineCoreRequests). # Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor( self.processor = Processor(
model_config=vllm_config.model_config, vllm_config=vllm_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
input_registry=input_registry, input_registry=input_registry,
) )
...@@ -194,8 +192,8 @@ class AsyncLLM(EngineClient): ...@@ -194,8 +192,8 @@ class AsyncLLM(EngineClient):
* 3) Adding the Request to the Detokenizer. * 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process). * 4) Adding the Request to the EngineCore (separate process).
A separate output_handler loop runs in a background AsyncIO task, A separate output_handler loop runs in a background AsyncIO task,
pulling outputs from EngineCore and putting them into the pulling outputs from EngineCore and putting them into the
per-request AsyncStream. per-request AsyncStream.
The caller of generate() iterates the returned AsyncGenerator, The caller of generate() iterates the returned AsyncGenerator,
......
...@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor ...@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -61,6 +62,8 @@ class EngineCore: ...@@ -61,6 +62,8 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler. # Setup scheduler.
self.scheduler = Scheduler( self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config, scheduler_config=vllm_config.scheduler_config,
...@@ -69,6 +72,7 @@ class EngineCore: ...@@ -69,6 +72,7 @@ class EngineCore:
lora_config=vllm_config.lora_config, lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config, speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats, log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
) )
# Setup MM Input Mapper. # Setup MM Input Mapper.
...@@ -131,6 +135,9 @@ class EngineCore: ...@@ -131,6 +135,9 @@ class EngineCore:
request.mm_inputs, request.mm_hashes) request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request)
if req.use_structured_output:
# Start grammar compilation asynchronously
self.structured_output_manager.populate_cache(req)
self.scheduler.add_request(req) self.scheduler.add_request(req)
...@@ -148,11 +155,24 @@ class EngineCore: ...@@ -148,11 +155,24 @@ class EngineCore:
if not self.scheduler.has_unfinished_requests(): if not self.scheduler.has_unfinished_requests():
return EngineCoreOutputs( return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats()) outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
output = self.model_executor.execute_model(scheduler_output) output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore scheduler_output, output) # type: ignore
return engine_core_outputs return engine_core_outputs
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]: def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
......
...@@ -66,9 +66,7 @@ class LLMEngine: ...@@ -66,9 +66,7 @@ class LLMEngine:
self.tokenizer.ping() self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests) # Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(model_config=vllm_config.model_config, self.processor = Processor(vllm_config=vllm_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
input_registry=input_registry, input_registry=input_registry,
mm_registry=mm_registry) mm_registry=mm_registry)
......
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Optional, Union from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import is_encoder_decoder_inputs
...@@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams ...@@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.structured_output.utils import validate_structured_output_request
class Processor: class Processor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
tokenizer: BaseTokenizerGroup, tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
): ):
self.model_config = model_config self.vllm_config = vllm_config
self.cache_config = cache_config self.model_config = vllm_config.model_config
self.lora_config = lora_config self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.generation_config_fields = model_config.try_get_generation_config( self.generation_config_fields = (
) self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(model_config, self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer, self.tokenizer,
mm_registry) mm_registry)
self.input_processor = input_registry.create_input_processor( self.input_processor = input_registry.create_input_processor(
model_config) self.model_config)
# Multi-modal (huggingface) input mapper # Multi-modal (huggingface) input mapper
self.mm_input_cache_client = MMInputCacheClient(model_config) self.mm_input_cache_client = MMInputCacheClient(self.model_config)
# Multi-modal hasher (for images) # Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \ self.use_hash = (
cache_config.enable_prefix_caching not self.model_config.disable_mm_preprocessor_cache) or \
self.cache_config.enable_prefix_caching
def _validate_logprobs( def _validate_logprobs(
self, self,
...@@ -80,6 +82,8 @@ class Processor: ...@@ -80,6 +82,8 @@ class Processor:
self, self,
params: SamplingParams, params: SamplingParams,
) -> None: ) -> None:
self._validate_structured_output(params)
if params.allowed_token_ids is None: if params.allowed_token_ids is None:
return return
if not params.allowed_token_ids: if not params.allowed_token_ids:
...@@ -125,6 +129,21 @@ class Processor: ...@@ -125,6 +129,21 @@ class Processor:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if (params.guided_decoding.backend
and params.guided_decoding.backend != 'xgrammar'):
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with "
"speculative decoding.")
validate_structured_output_request(params)
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment