Commit a130cf33 authored by zhuwenwen's avatar zhuwenwen

Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

parents a2d181be 82091b86
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "google/gemma-7b"
def do_sample(llm, lora_path: str, lora_id: int) -> str:
prompts = [
"Quote: Imagination is",
"Quote: Be yourself;",
"Quote: So many books,",
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_gemma_lora(gemma_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4)
expected_lora_output = [
"more important than knowledge.\nAuthor: Albert Einstein\n",
"everyone else is already taken.\nAuthor: Oscar Wilde\n",
"so little time\nAuthor: Frank Zappa\n",
]
output1 = do_sample(llm, gemma_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i].startswith(expected_lora_output[i])
output2 = do_sample(llm, gemma_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i].startswith(expected_lora_output[i])
......@@ -279,7 +279,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
256,
org_num_embeddings=512)
expanded_embedding.weight.data[:512, :] = embedding_data
# We need to deepcopy the embedding as it will be modifed
# We need to deepcopy the embedding as it will be modified
# in place
lora_embedding = VocabParallelEmbeddingWithLoRA(
deepcopy(expanded_embedding))
......
......@@ -15,7 +15,7 @@ def do_sample(llm, lora_path: str, lora_id: int):
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]"
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
......@@ -53,7 +53,7 @@ def test_llama_lora(sql_lora_files, tp_size):
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
]
expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
......
......@@ -44,8 +44,8 @@ def _lora_ref_impl(
H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
32256, 32512, 32768, 33024
5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
24576, 32000, 32256, 32512, 32768, 33024
]
SEED = [0xabcdabcd987]
......
import pytest
import vllm.engine.metrics
MODELS = [
"facebook/opt-125m",
......@@ -9,14 +8,17 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metrics(
def test_metric_counter_prompt_tokens(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
vllm_model = vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4)
tokenizer = vllm_model.model.get_tokenizer()
prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
# This test needs at least 2 prompts in a batch of different lengths to verify their token count is correct despite padding.
......@@ -26,8 +28,41 @@ def test_metrics(
vllm_prompt_token_count = sum(prompt_token_counts)
_ = vllm_model.generate_greedy(example_prompts, max_tokens)
metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
stat_logger = vllm_model.model.llm_engine.stat_logger
metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
**stat_logger.labels)._value.get()
assert vllm_prompt_token_count == metric_count, (
f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metric_counter_generation_tokens(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
vllm_model = vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
tokenizer = vllm_model.model.get_tokenizer()
stat_logger = vllm_model.model.llm_engine.stat_logger
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
**stat_logger.labels)._value.get()
vllm_generation_count = 0
for i in range(len(example_prompts)):
vllm_output_ids, vllm_output_str = vllm_outputs[i]
prompt_ids = tokenizer.encode(example_prompts[i])
# vllm_output_ids contains both prompt tokens and generation tokens. We're interested only in the count of the generation tokens.
vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
assert vllm_generation_count == metric_count, (
f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}"
)
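Both tests above read labeled Counter values through the prometheus_client child's internal _value.get() accessor. A minimal, self-contained sketch of that pattern, using a throwaway registry and hypothetical metric/label names that are not part of this commit:

from prometheus_client import CollectorRegistry, Counter

registry = CollectorRegistry()  # isolated registry so the sketch is re-runnable
counter = Counter("example_prompt_tokens_total",
                  "Number of prefill tokens processed.",
                  labelnames=["model_name"],
                  registry=registry)
labels = {"model_name": "facebook/opt-125m"}
counter.labels(**labels).inc(42)
# The tests above use this same internal accessor to read the current value.
assert counter.labels(**labels)._value.get() == 42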
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin, so we re-run the test up to
3 times to see if it passes.
Run `pytest tests/models/test_marlin.py --forked`.
"""
import pytest
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (
capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
@dataclass
class ModelPair:
model_marlin: str
model_gptq: str
model_pairs = [
ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128",
model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"),
ModelPair(model_marlin="robertgshaw2/zephyr-7b-beta-channelwise-marlin",
model_gptq="robertgshaw2/zephyr-7b-beta-channelwise-gptq"),
ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin",
model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq")
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
marlin_outputs = marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
# Note: not sure why, but deleting just the model on Ada Lovelace
# does not free the GPU memory. On Ampere, deleting just the model
# frees the memory.
del marlin_model.model.llm_engine.driver_worker
del marlin_model
gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
# Note: not sure why, but deleting just the model on Ada Lovelace
# does not free the GPU memory. On Ampere, deleting just the model
# frees the memory.
del gptq_model.model.llm_engine.driver_worker
del gptq_model
# loop through the prompts
for prompt_idx in range(len(example_prompts)):
gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
prompt_idx]
marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
prompt_idx]
for idx, (gptq_output_id, marlin_output_id) in enumerate(
zip(gptq_output_ids, marlin_output_ids)):
# If sequence is not an exact match,
if marlin_output_id != gptq_output_id:
# Each predicted token must be in the top 3 of the other's
assert gptq_output_id in marlin_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
)
assert marlin_output_id in gptq_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
)
# Break out since sequences will now diverge.
break
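To illustrate the acceptance criterion from the docstring above (when the greedy tokens diverge, each model's pick must appear in the other's top-num_logprobs set), here is a toy version of the check with made-up token ids, separate from the actual test:

# Hypothetical per-position top-3 token-id sets and greedy picks for two models.
marlin_top3 = [{1, 2, 3}, {7, 8, 9}]
gptq_top3 = [{1, 4, 5}, {6, 7, 9}]
marlin_greedy = [1, 7]
gptq_greedy = [1, 9]

for pos, (marlin_id, gptq_id) in enumerate(zip(marlin_greedy, gptq_greedy)):
    if marlin_id != gptq_id:
        # Each model's pick must be acceptable (top-3) to the other model.
        assert gptq_id in marlin_top3[pos]
        assert marlin_id in gptq_top3[pos]
        break  # the sequences diverge from here on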
......@@ -19,6 +19,7 @@ MODELS = [
"microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t",
"allenai/OLMo-1B",
"bigcode/starcoder2-3b",
]
......
......@@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.version import __dcu_version__
__version__ = "0.3.2"
__version__ = "0.3.3"
__all__ = [
"LLM",
......
......@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
logger = init_logger(__name__)
......@@ -155,15 +155,21 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm"]
rocm_not_supported_quantization = ["awq"]
supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
rocm_not_supported_quantization = ["awq", "marlin"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available.
hf_quant_config = getattr(self.hf_config, "quantization_config", None)
if hf_quant_config is not None:
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
# If the GPTQ model is serialized in marlin format, use marlin.
if (hf_quant_method == "gptq"
and "is_marlin_format" in hf_quant_config
and hf_quant_config["is_marlin_format"]):
hf_quant_method = "marlin"
if self.quantization is None:
self.quantization = hf_quant_method
elif self.quantization != hf_quant_method:
......@@ -183,9 +189,11 @@ class ModelConfig:
raise ValueError(
f"{self.quantization} quantization is currently not supported "
f"in ROCm.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
if self.quantization != "marlin":
logger.warning(
f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
......@@ -308,6 +316,10 @@ class CacheConfig:
self.num_gpu_blocks = None
self.num_cpu_blocks = None
def metrics_info(self):
# Convert cache_config to a Dict[str, str] for Prometheus metrics info.
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None:
if self.gpu_memory_utilization > 1.0:
raise ValueError(
......@@ -319,7 +331,7 @@ class CacheConfig:
pass
elif self.cache_dtype == "fp8_e5m2":
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version < Version("11.8"):
if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
raise ValueError(
"FP8 is not supported when cuda version is lower than 11.8."
)
......@@ -380,13 +392,21 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
if is_neuron():
# For Neuron device support, we assign TP=1 here to avoid sharding within vLLM directly.
# transformers-neuronx reads the neuron_tp_degree attribute and distributes the
# workload across multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend.
if self.world_size > 1 and not is_neuron():
self.worker_use_ray = True
self._verify_args()
......@@ -465,8 +485,29 @@ class SchedulerConfig:
class DeviceConfig:
def __init__(self, device: str = "cuda") -> None:
self.device = torch.device(device)
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if torch.cuda.is_available():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron"
else:
raise RuntimeError("No supported device detected.")
else:
# Device type is assigned explicitly
self.device_type = device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
else:
# Set device with device type
self.device = torch.device(self.device_type)
@property
def is_neuron(self):
return self.device_type == "neuron"
@dataclass
......
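For context on the _verify_quantization change above, this is how a hypothetical quantization_config from a GPTQ checkpoint serialized in Marlin format would be resolved by the new branch (field values are illustrative only):

hf_quant_config = {          # hypothetical HF config fragment
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "is_marlin_format": True,
}
hf_quant_method = str(hf_quant_config["quant_method"]).lower()
# Same rule as the diff: a GPTQ checkpoint flagged as marlin-format is treated as marlin.
if (hf_quant_method == "gptq"
        and hf_quant_config.get("is_marlin_format")):
    hf_quant_method = "marlin"
assert hf_quant_method == "marlin"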
......@@ -178,7 +178,7 @@ class BlockSpaceManager:
if len(block_table) < len(logical_blocks):
if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
# re-use a block
# reuse a block
block_table.append(block_table[len(block_table) %
self.block_sliding_window])
else:
......
......@@ -158,7 +158,7 @@ class Scheduler:
return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swaped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
......
......@@ -44,7 +44,7 @@ class EngineArgs:
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'cuda'
device: str = 'auto'
def __post_init__(self):
if self.tokenizer is None:
......@@ -171,7 +171,7 @@ class EngineArgs:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
choices=[8, 16, 32, 128],
help='token block size')
parser.add_argument('--seed',
type=int,
......@@ -264,13 +264,11 @@ class EngineArgs:
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument(
"--device",
type=str,
default=EngineArgs.device,
choices=["cuda"],
help=('Device type for vLLM execution. '
'Currently, only CUDA-compatible devices are supported.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron"],
help='Device type for vLLM execution.')
return parser
@classmethod
......
......@@ -333,6 +333,9 @@ class AsyncLLMEngine:
return (self.background_loop is not None
and not self.background_loop.done())
def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
def start_background_loop(self) -> None:
"""Start the background loop."""
if self.is_running:
......
......@@ -3,6 +3,7 @@ from collections import defaultdict
import os
import time
import pickle
import importlib
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
......@@ -20,7 +21,8 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
TokenizerGroup)
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
get_open_port, get_distributed_init_method)
if ray:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
......@@ -31,6 +33,12 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
......@@ -128,7 +136,9 @@ class LLMEngine:
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config)
self.forward_dag = None
if USE_RAY_COMPILED_DAG:
......@@ -137,10 +147,17 @@ class LLMEngine:
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker
Worker = self._dispatch_worker()
assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
......@@ -242,7 +259,7 @@ class LLMEngine:
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker
Worker = self._dispatch_worker()
# Initialize torch distributed process group for the workers.
model_config = copy.deepcopy(self.model_config)
......@@ -283,7 +300,10 @@ class LLMEngine:
is_driver_worker=True,
)
self._run_workers("init_model", cupy_port=get_open_port())
# don't use cupy for eager mode
self._run_workers("init_model",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
......@@ -464,8 +484,9 @@ class LLMEngine:
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Defensive copy of SamplingParams, which are used by the sampler
sampling_params = copy.deepcopy(sampling_params)
# Defensive copy of SamplingParams, which are used by the sampler;
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
......@@ -872,6 +893,9 @@ class LLMEngine:
num_prompt_tokens = sum(
len(seq_group.prompt_token_ids)
for seq_group in scheduler_outputs.scheduled_seq_groups)
num_generation_tokens = sum(
seq_group.num_seqs()
for seq_group in scheduler_outputs.scheduled_seq_groups)
else:
num_generation_tokens = scheduler_outputs.num_batched_tokens
......@@ -956,7 +980,10 @@ class LLMEngine:
def _finalize_sequence(self, seq: Sequence,
sampling_params: SamplingParams,
stop_string: str) -> None:
if not sampling_params.include_stop_str_in_output and stop_string:
if sampling_params.include_stop_str_in_output:
return
if stop_string and seq.output_text.endswith(stop_string):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_string)]
......
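The DEVICE_TO_WORKER_MODULE_MAP and _dispatch_worker changes above replace a hard-coded Worker import with a lazy, string-keyed lookup. A standalone sketch of that pattern (module paths taken from the diff; this is not the full engine code):

import importlib

DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
}

def dispatch_worker(device_type: str):
    # Import lazily so that e.g. torch.cuda/xformers are not pulled in before
    # CUDA_VISIBLE_DEVICES has been set for the worker process.
    worker_module = importlib.import_module(DEVICE_TO_WORKER_MODULE_MAP[device_type])
    return worker_module.Worker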
from vllm.logger import init_logger
from aioprometheus import Counter, Gauge, Histogram
from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics
import time
import numpy as np
from typing import List
from typing import Dict, List
from dataclasses import dataclass
logger = init_logger(__name__)
labels = {}
def add_global_metrics_labels(**kwargs):
labels.update(kwargs)
disable_created_metrics()
# The begin-* and end-* markers here are used by the documentation generator
# to extract the metrics definitions.
# begin-metrics-definitions
gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s",
"Average prefill throughput in tokens/s.")
gauge_avg_generation_throughput = Gauge(
"vllm:avg_generation_throughput_toks_per_s",
"Average generation throughput in tokens/s.")
counter_prompt_tokens = Counter("vllm:prompt_tokens_total",
"Number of prefill tokens processed.")
counter_generation_tokens = Counter("vllm:generation_tokens_total",
"Number of generation tokens processed.")
gauge_scheduler_running = Gauge(
"vllm:num_requests_running",
"Number of requests currently running on GPU.")
gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped",
"Number of requests swapped to CPU.")
gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting",
"Number of requests waiting to be processed.")
gauge_gpu_cache_usage = Gauge(
"vllm:gpu_cache_usage_perc",
"GPU KV-cache usage. 1 means 100 percent usage.")
gauge_cpu_cache_usage = Gauge(
"vllm:cpu_cache_usage_perc",
"CPU KV-cache usage. 1 means 100 percent usage.")
histogram_time_to_first_token = Histogram(
"vllm:time_to_first_token_seconds",
"Histogram of time to first token in seconds.",
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0,
2.5, 5.0, 7.5, 10.0
])
histogram_time_per_output_tokens = Histogram(
"vllm:time_per_output_token_seconds",
"Histogram of time per output token in seconds.",
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5
])
histogram_e2e_request_latency = Histogram(
"vllm:e2e_request_latency_seconds",
"Histogram of end to end request latency in seconds.",
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
class Metrics:
def __init__(self, labelnames: List[str]):
# Unregister any existing vLLM collectors
for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
REGISTRY.unregister(collector)
self.info_cache_config = Info(
name='vllm:cache_config',
documentation='information of cache_config')
# System stats
self.gauge_scheduler_running = Gauge(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Raw stats from last model iteration
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = Counter(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = Histogram(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
self.histogram_e2e_request_latency = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Legacy metrics
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
# end-metrics-definitions
......@@ -87,7 +119,7 @@ class Stats:
class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float) -> None:
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally.
self.last_local_log = time.monotonic()
self.local_interval = local_interval
......@@ -96,6 +128,14 @@ class StatLogger:
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log))
......@@ -105,23 +145,33 @@ class StatLogger:
def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges.
gauge_scheduler_running.set(labels, stats.num_running)
gauge_scheduler_swapped.set(labels, stats.num_swapped)
gauge_scheduler_waiting.set(labels, stats.num_waiting)
gauge_gpu_cache_usage.set(labels, stats.gpu_cache_usage)
gauge_cpu_cache_usage.set(labels, stats.cpu_cache_usage)
self.metrics.gauge_scheduler_running.labels(**self.labels).set(
stats.num_running)
self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
stats.num_swapped)
self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
stats.num_waiting)
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
stats.gpu_cache_usage)
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
stats.cpu_cache_usage)
# Add to token counters.
counter_prompt_tokens.add(labels, stats.num_prompt_tokens)
counter_generation_tokens.add(labels, stats.num_generation_tokens)
self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
stats.num_prompt_tokens)
self.metrics.counter_generation_tokens.labels(**self.labels).inc(
stats.num_generation_tokens)
# Observe request level latencies in histograms.
for ttft in stats.time_to_first_tokens:
histogram_time_to_first_token.observe(labels, ttft)
self.metrics.histogram_time_to_first_token.labels(
**self.labels).observe(ttft)
for tpot in stats.time_per_output_tokens:
histogram_time_per_output_tokens.observe(labels, tpot)
self.metrics.histogram_time_per_output_token.labels(
**self.labels).observe(tpot)
for e2e in stats.time_e2e_requests:
histogram_e2e_request_latency.observe(labels, e2e)
self.metrics.histogram_e2e_request_latency.labels(
**self.labels).observe(e2e)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
......@@ -130,8 +180,10 @@ class StatLogger:
# Moving forward, we should use counters like counter_prompt_tokens and counter_generation_tokens,
# which log raw data and calculate summaries using rate() on the Grafana/Prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
gauge_avg_prompt_throughput.set(labels, prompt_throughput)
gauge_avg_generation_throughput.set(labels, generation_throughput)
self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
......
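The new cache_config Info metric above is fed by CacheConfig.metrics_info(), which stringifies the config fields. A minimal sketch of the prometheus_client Info usage, with a throwaway registry and example values that are not part of this commit:

from prometheus_client import CollectorRegistry, Info

registry = CollectorRegistry()
info_cache_config = Info(name="example_cache_config",
                         documentation="information of cache_config",
                         registry=registry)
# CacheConfig.metrics_info() returns a Dict[str, str] along these lines.
info_cache_config.info({"block_size": "16", "gpu_memory_utilization": "0.9"})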
......@@ -6,8 +6,7 @@ import os
import importlib
import inspect
from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
......@@ -18,7 +17,6 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import add_global_metrics_labels
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
......@@ -141,8 +139,9 @@ def parse_args():
return parser.parse_args()
app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
app.add_route("/metrics", metrics) # Exposes HTTP metrics
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.exception_handler(RequestValidationError)
......@@ -242,9 +241,6 @@ if __name__ == "__main__":
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)
# Register labels for metrics
add_global_metrics_labels(model_name=engine_args.model)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
......
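A minimal, standalone sketch of the new /metrics wiring above, assuming fastapi, uvicorn and prometheus_client are installed (the real server registers many more routes):

import fastapi
import uvicorn
from prometheus_client import make_asgi_app

app = fastapi.FastAPI()
# Mount the Prometheus ASGI app as a sub-application instead of using
# aioprometheus middleware.
app.mount("/metrics", make_asgi_app())

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)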
......@@ -3,11 +3,13 @@
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
import torch
class ErrorResponse(BaseModel):
object: str = "error"
......@@ -55,7 +57,7 @@ class UsageInfo(BaseModel):
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
messages: List[Dict[str, str]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
......@@ -63,6 +65,8 @@ class ChatCompletionRequest(BaseModel):
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
......@@ -72,6 +76,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
......@@ -81,8 +86,28 @@ class ChatCompletionRequest(BaseModel):
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
......@@ -95,16 +120,34 @@ class ChatCompletionRequest(BaseModel):
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class CompletionRequest(BaseModel):
model: str
......@@ -129,6 +172,7 @@ class CompletionRequest(BaseModel):
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
......@@ -136,10 +180,27 @@ class CompletionRequest(BaseModel):
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
best_of=self.best_of,
......@@ -157,13 +218,29 @@ class CompletionRequest(BaseModel):
max_tokens=self.max_tokens if not echo_without_generation else 1,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
......@@ -212,6 +289,7 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
......@@ -232,6 +310,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
......
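A toy illustration of the logit_bias clamping rule used by the new logits processors in both request classes above (biases are clamped to [-100, 100] per the OpenAI API spec; the token ids and bias values here are made up):

import torch

logit_bias = {"50256": -200.0, "11": 30.0}    # hypothetical request payload
logits = torch.zeros(50257)
for token_id, bias in logit_bias.items():
    bias = min(100, max(-100, bias))          # clamp per the OpenAI API spec
    logits[int(token_id)] += bias
assert logits[50256].item() == -100.0 and logits[11].item() == 30.0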
......@@ -12,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__)
......@@ -39,19 +40,13 @@ class OpenAIServingChat(OpenAIServing):
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return self.create_error_response(
"logit_bias is not currently supported")
try:
prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
......@@ -68,6 +63,14 @@ class OpenAIServingChat(OpenAIServing):
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer()))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
except ValueError as e:
return self.create_error_response(str(e))
......@@ -86,7 +89,7 @@ class OpenAIServingChat(OpenAIServing):
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1].role
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
......@@ -101,7 +104,10 @@ class OpenAIServingChat(OpenAIServing):
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
......@@ -118,6 +124,7 @@ class OpenAIServingChat(OpenAIServing):
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
......@@ -129,6 +136,7 @@ class OpenAIServingChat(OpenAIServing):
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
......@@ -145,15 +153,29 @@ class OpenAIServingChat(OpenAIServing):
if finish_reason_sent[i]:
continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
......@@ -174,6 +196,7 @@ class OpenAIServingChat(OpenAIServing):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
......@@ -208,11 +231,25 @@ class OpenAIServingChat(OpenAIServing):
assert final_res is not None
choices = []
role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
......
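Hypothetical client-side usage of the guided decoding wiring above, assuming a vLLM OpenAI-compatible server on localhost:8000 and the openai>=1.0 Python client; only one of guided_json, guided_regex or guided_choice may be set per request, and the model name below is a placeholder:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "Is this review positive or negative?"}],
    # Extra, vLLM-specific field; constrains generation to one of the choices.
    extra_body={"guided_choice": ["positive", "negative"]},
)
print(response.choices[0].message.content)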
......@@ -5,7 +5,7 @@ from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from .protocol import (
from vllm.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
......@@ -16,6 +16,7 @@ from .protocol import (
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__)
......@@ -95,7 +96,7 @@ async def completion_stream_generator(
logprobs=logprobs,
finish_reason=finish_reason,
)
]).model_dump_json(exclude_unset=True)
]).model_dump_json()
yield f"data: {response_json}\n\n"
if output.finish_reason is not None: # return final usage
......@@ -120,7 +121,7 @@ async def completion_stream_generator(
)
],
usage=final_usage,
).model_dump_json(exclude_unset=True)
).model_dump_json()
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
......@@ -264,10 +265,9 @@ class OpenAIServingCompletion(OpenAIServing):
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
......@@ -277,9 +277,6 @@ class OpenAIServingCompletion(OpenAIServing):
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
return self.create_error_response(
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
......@@ -290,6 +287,14 @@ class OpenAIServingCompletion(OpenAIServing):
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
......@@ -301,7 +306,7 @@ class OpenAIServingCompletion(OpenAIServing):
request, prompt=prompt)
generators.append(
self.engine.generate(None,
self.engine.generate(prompt,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids,
......
......@@ -795,6 +795,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
self.dtype = dtype
self.device = device
@property
def logits_as_hidden_states(self):
return self.base_layer.logits_as_hidden_states
@property
def vocab_size(self):
return self.base_layer.vocab_size
......