Commit a130cf33 authored by zhuwenwen

Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

parents a2d181be 82091b86
import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "google/gemma-7b"


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        "Quote: Imagination is",
        "Quote: Be yourself;",
        "Quote: So many books,",
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_gemma_lora(gemma_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4)

    expected_lora_output = [
        "more important than knowledge.\nAuthor: Albert Einstein\n",
        "everyone else is already taken.\nAuthor: Oscar Wilde\n",
        "so little time\nAuthor: Frank Zappa\n",
    ]

    output1 = do_sample(llm, gemma_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i].startswith(expected_lora_output[i])
    output2 = do_sample(llm, gemma_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i].startswith(expected_lora_output[i])
@@ -279,7 +279,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
                                              256,
                                              org_num_embeddings=512)
         expanded_embedding.weight.data[:512, :] = embedding_data
-        # We need to deepcopy the embedding as it will be modifed
+        # We need to deepcopy the embedding as it will be modified
         # in place
         lora_embedding = VocabParallelEmbeddingWithLoRA(
             deepcopy(expanded_embedding))
...
...@@ -15,7 +15,7 @@ def do_sample(llm, lora_path: str, lora_id: int): ...@@ -15,7 +15,7 @@ def do_sample(llm, lora_path: str, lora_id: int):
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]",
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]" "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"
] ]
sampling_params = vllm.SamplingParams(temperature=0, sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256, max_tokens=256,
...@@ -53,7 +53,7 @@ def test_llama_lora(sql_lora_files, tp_size): ...@@ -53,7 +53,7 @@ def test_llama_lora(sql_lora_files, tp_size):
"\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ",
" Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ",
"\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE",
] ]
expected_lora_output = [ expected_lora_output = [
" SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ",
......
@@ -44,8 +44,8 @@ def _lora_ref_impl(
 H1 = H2 = [
     128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
-    32256, 32512, 32768, 33024
+    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
+    24576, 32000, 32256, 32512, 32768, 33024
 ]
 SEED = [0xabcdabcd987]
...
 import pytest
-import vllm.engine.metrics

 MODELS = [
     "facebook/opt-125m",
@@ -9,14 +8,17 @@ MODELS = [
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])
 @pytest.mark.parametrize("max_tokens", [128])
-def test_metrics(
+def test_metric_counter_prompt_tokens(
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
     max_tokens: int,
 ) -> None:
-    vllm_model = vllm_runner(model, dtype=dtype, disable_log_stats=False)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.4)
     tokenizer = vllm_model.model.get_tokenizer()
     prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts]
     # This test needs at least 2 prompts in a batch of different lengths to
     # verify their token count is correct despite padding.
@@ -26,8 +28,41 @@ def test_metrics(
     vllm_prompt_token_count = sum(prompt_token_counts)
     _ = vllm_model.generate_greedy(example_prompts, max_tokens)
-    metric_count = vllm.engine.metrics.counter_prompt_tokens.get_value({})
+    stat_logger = vllm_model.model.llm_engine.stat_logger
+    metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
+        **stat_logger.labels)._value.get()

     assert vllm_prompt_token_count == metric_count, (
         f"prompt token count: {vllm_prompt_token_count!r}\nmetric: {metric_count!r}"
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [128])
+def test_metric_counter_generation_tokens(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.4)
+    vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+    tokenizer = vllm_model.model.get_tokenizer()
+    stat_logger = vllm_model.model.llm_engine.stat_logger
+    metric_count = stat_logger.metrics.counter_generation_tokens.labels(
+        **stat_logger.labels)._value.get()
+
+    vllm_generation_count = 0
+    for i in range(len(example_prompts)):
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        prompt_ids = tokenizer.encode(example_prompts[i])
+        # vllm_output_ids contains both prompt tokens and generation tokens.
+        # We're interested only in the count of the generation tokens.
+        vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
+
+    assert vllm_generation_count == metric_count, (
+        f"generation token count: {vllm_generation_count!r}\nmetric: {metric_count!r}"
+    )
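For context on how these assertions read the counters: the new StatLogger registers prometheus_client metrics with a model_name label, and the tests pull the raw child value through the internal _value accessor. A minimal standalone sketch of that pattern follows; the metric and label names are illustrative, not vLLM's own.

from prometheus_client import CollectorRegistry, Counter

registry = CollectorRegistry()
# Labeled counter mirroring the shape of vllm:prompt_tokens_total.
counter_prompt_tokens = Counter("demo_prompt_tokens_total",
                                "Number of prefill tokens processed.",
                                labelnames=["model_name"],
                                registry=registry)

labels = {"model_name": "facebook/opt-125m"}
counter_prompt_tokens.labels(**labels).inc(42)

# Same read path the tests use: the child metric's internal value.
assert counter_prompt_tokens.labels(**labels)._value.get() == 42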
"""Compare the outputs of a GPTQ model to a Marlin model.
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py --forked`.
"""
import pytest
import torch
from dataclasses import dataclass

from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (
    capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())


@dataclass
class ModelPair:
    model_marlin: str
    model_gptq: str


model_pairs = [
    ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128",
              model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"),
    ModelPair(model_marlin="robertgshaw2/zephyr-7b-beta-channelwise-marlin",
              model_gptq="robertgshaw2/zephyr-7b-beta-channelwise-gptq"),
    ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin",
              model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq")
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
    marlin_outputs = marlin_model.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)

    # Note: not sure why, but deleting just the model on Ada Lovelace
    # does not free the GPU memory. On Ampere, deleting just the model
    # frees the memory.
    del marlin_model.model.llm_engine.driver_worker
    del marlin_model

    gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
    gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)

    # Note: not sure why, but deleting just the model on Ada Lovelace
    # does not free the GPU memory. On Ampere, deleting just the model
    # frees the memory.
    del gptq_model.model.llm_engine.driver_worker
    del gptq_model

    # Loop through the prompts.
    for prompt_idx in range(len(example_prompts)):
        gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
            prompt_idx]
        marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
            prompt_idx]

        for idx, (gptq_output_id, marlin_output_id) in enumerate(
                zip(gptq_output_ids, marlin_output_ids)):
            # If the sequences are not an exact match,
            if marlin_output_id != gptq_output_id:
                # each predicted token must be in the other's top num_logprobs.
                assert gptq_output_id in marlin_logprobs[idx], (
                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
                )
                assert marlin_output_id in gptq_logprobs[idx], (
                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
                )

                # Break out since the sequences will now diverge.
                break
@@ -19,6 +19,7 @@ MODELS = [
     "microsoft/phi-2",
     "stabilityai/stablelm-3b-4e1t",
     "allenai/OLMo-1B",
+    "bigcode/starcoder2-3b",
 ]
...
@@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.version import __dcu_version__

-__version__ = "0.3.2"
+__version__ = "0.3.3"

 __all__ = [
     "LLM",
...
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

 logger = init_logger(__name__)
@@ -155,15 +155,21 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode

     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm"]
-        rocm_not_supported_quantization = ["awq"]
+        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
+        rocm_not_supported_quantization = ["awq", "marlin"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

         # Parse quantization method from the HF model config, if available.
         hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
+            # If the GPTQ model is serialized in marlin format, use marlin.
+            if (hf_quant_method == "gptq"
+                    and "is_marlin_format" in hf_quant_config
+                    and hf_quant_config["is_marlin_format"]):
+                hf_quant_method = "marlin"
             if self.quantization is None:
                 self.quantization = hf_quant_method
             elif self.quantization != hf_quant_method:
@@ -183,9 +189,11 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not supported "
                     f"in ROCm.")
-            logger.warning(f"{self.quantization} quantization is not fully "
-                           "optimized yet. The speed can be slower than "
-                           "non-quantized models.")
+            if self.quantization != "marlin":
+                logger.warning(
+                    f"{self.quantization} quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.")

     def _verify_cuda_graph(self) -> None:
         if self.max_context_len_to_capture is None:
@@ -308,6 +316,10 @@ class CacheConfig:
         self.num_gpu_blocks = None
         self.num_cpu_blocks = None

+    def metrics_info(self):
+        # Convert cache_config to a dict of str -> str for the Prometheus Info metric.
+        return {key: str(value) for key, value in self.__dict__.items()}
+
     def _verify_args(self) -> None:
         if self.gpu_memory_utilization > 1.0:
             raise ValueError(
@@ -319,7 +331,7 @@ class CacheConfig:
             pass
         elif self.cache_dtype == "fp8_e5m2":
             nvcc_cuda_version = get_nvcc_cuda_version()
-            if nvcc_cuda_version < Version("11.8"):
+            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                 raise ValueError(
                     "FP8 is not supported when cuda version is lower than 11.8."
                 )
@@ -380,13 +392,21 @@ class ParallelConfig:
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, assign TP=1 here to avoid sharding
+            # within vLLM directly. Transformers-neuronx takes the
+            # neuron_tp_degree attribute and distributes the workload
+            # to multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray worker is not supported for the Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()
@@ -465,8 +485,29 @@ class SchedulerConfig:

 class DeviceConfig:

-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"


 @dataclass
...
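A note on the new CacheConfig.metrics_info() hook above: it feeds a prometheus_client Info metric registered by the engine's StatLogger. A minimal standalone sketch of that API follows; the metric name and config values are illustrative stand-ins, not vLLM's own.

from prometheus_client import CollectorRegistry, Info, generate_latest

registry = CollectorRegistry()
# Hypothetical stand-in for vllm:cache_config; name and values are illustrative only.
cache_config_info = Info("demo_cache_config",
                         "Information of the cache config",
                         registry=registry)
cache_config_info.info({"block_size": "16", "cache_dtype": "auto"})

# Renders as: demo_cache_config_info{block_size="16",cache_dtype="auto"} 1.0
print(generate_latest(registry).decode())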
@@ -178,7 +178,7 @@ class BlockSpaceManager:
         if len(block_table) < len(logical_blocks):
             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
-                # re-use a block
+                # reuse a block
                 block_table.append(block_table[len(block_table) %
                                                self.block_sliding_window])
             else:
...
@@ -158,7 +158,7 @@ class Scheduler:
         return len(self.waiting) + len(self.running) + len(self.swapped)

     def _schedule(self) -> SchedulerOutputs:
-        # Blocks that need to be swaped or copied before model execution.
+        # Blocks that need to be swapped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
...
@@ -44,7 +44,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = 'auto'

     def __post_init__(self):
         if self.tokenizer is None:
@@ -171,7 +171,7 @@ class EngineArgs:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
+                            choices=[8, 16, 32, 128],
                             help='token block size')
         parser.add_argument('--seed',
                             type=int,
@@ -264,13 +264,11 @@ class EngineArgs:
                             help=('Maximum number of LoRAs to store in CPU memory. '
                                   'Must be >= than max_num_seqs. '
                                   'Defaults to max_num_seqs.'))
-        parser.add_argument(
-            "--device",
-            type=str,
-            default=EngineArgs.device,
-            choices=["cuda"],
-            help=('Device type for vLLM execution. '
-                  'Currently, only CUDA-compatible devices are supported.'))
+        parser.add_argument("--device",
+                            type=str,
+                            default=EngineArgs.device,
+                            choices=["auto", "cuda", "neuron"],
+                            help='Device type for vLLM execution.')
         return parser

     @classmethod
...
@@ -333,6 +333,9 @@ class AsyncLLMEngine:
         return (self.background_loop is not None
                 and not self.background_loop.done())

+    def get_tokenizer(self):
+        return self.engine.tokenizer.tokenizer
+
     def start_background_loop(self) -> None:
         """Start the background loop."""
         if self.is_running:
...
@@ -3,6 +3,7 @@ from collections import defaultdict
 import os
 import time
 import pickle
+import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
@@ -20,7 +21,8 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
-from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
+from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
+                        get_open_port, get_distributed_init_method)

 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -31,6 +33,12 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

+# A map from the device type (in the device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
 # If the env var is set, it uses Ray's compiled DAG API,
 # which optimizes the control plane overhead.
 # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -128,7 +136,9 @@ class LLMEngine:
         # Metric Logging.
         if self.log_stats:
             self.stat_logger = StatLogger(
-                local_interval=_LOCAL_LOGGING_INTERVAL_SEC)
+                local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
+                labels=dict(model_name=model_config.model))
+            self.stat_logger.info("cache_config", self.cache_config)

         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
@@ -137,10 +147,17 @@ class LLMEngine:
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         assert self.parallel_config.world_size == 1, (
             "Ray is required if parallel_config.world_size > 1.")
@@ -242,7 +259,7 @@ class LLMEngine:
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
@@ -283,7 +300,10 @@ class LLMEngine:
             is_driver_worker=True,
         )

-        self._run_workers("init_model", cupy_port=get_open_port())
+        # Don't use CuPy for eager mode.
+        self._run_workers("init_model",
+                          cupy_port=get_open_port()
+                          if not model_config.enforce_eager else None)
         self._run_workers(
             "load_model",
             max_concurrent_workers=self.parallel_config.
@@ -464,8 +484,9 @@ class LLMEngine:
             prompt_token_ids[:prefix_pos], lora_request.lora_int_id
             if lora_request else 0) if prefix_pos is not None else None

-        # Defensive copy of SamplingParams, which are used by the sampler
-        sampling_params = copy.deepcopy(sampling_params)
+        # Defensive copy of SamplingParams, which are used by the sampler;
+        # this does not deep-copy LogitsProcessor objects.
+        sampling_params = sampling_params.clone()

         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -872,6 +893,9 @@ class LLMEngine:
             num_prompt_tokens = sum(
                 len(seq_group.prompt_token_ids)
                 for seq_group in scheduler_outputs.scheduled_seq_groups)
+            num_generation_tokens = sum(
+                seq_group.num_seqs()
+                for seq_group in scheduler_outputs.scheduled_seq_groups)
         else:
             num_generation_tokens = scheduler_outputs.num_batched_tokens
@@ -956,7 +980,10 @@ class LLMEngine:
     def _finalize_sequence(self, seq: Sequence,
                            sampling_params: SamplingParams,
                            stop_string: str) -> None:
-        if not sampling_params.include_stop_str_in_output and stop_string:
+        if sampling_params.include_stop_str_in_output:
+            return
+
+        if stop_string and seq.output_text.endswith(stop_string):
             # Truncate the output text so that the stop string is
             # not included in the output.
             seq.output_text = seq.output_text[:-len(stop_string)]
...
from vllm.logger import init_logger from vllm.logger import init_logger
from aioprometheus import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics
import time import time
import numpy as np import numpy as np
from typing import List from typing import Dict, List
from dataclasses import dataclass from dataclasses import dataclass
logger = init_logger(__name__) logger = init_logger(__name__)
labels = {} disable_created_metrics()
def add_global_metrics_labels(**kwargs):
labels.update(kwargs)
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions. # to extract the metrics definitions.
# begin-metrics-definitions # begin-metrics-definitions
gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s", class Metrics:
"Average prefill throughput in tokens/s.")
gauge_avg_generation_throughput = Gauge( def __init__(self, labelnames: List[str]):
"vllm:avg_generation_throughput_toks_per_s", # Unregister any existing vLLM collectors
"Average generation throughput in tokens/s.") for collector in list(REGISTRY._collector_to_names):
counter_prompt_tokens = Counter("vllm:prompt_tokens_total", if hasattr(collector, "_name") and "vllm" in collector._name:
"Number of prefill tokens processed.") REGISTRY.unregister(collector)
counter_generation_tokens = Counter("vllm:generation_tokens_total",
"Number of generation tokens processed.") self.info_cache_config = Info(
name='vllm:cache_config',
gauge_scheduler_running = Gauge( documentation='information of cache_config')
"vllm:num_requests_running",
"Number of requests currently running on GPU.") # System stats
gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped", self.gauge_scheduler_running = Gauge(
"Number of requests swapped to CPU.") name="vllm:num_requests_running",
gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting", documentation="Number of requests currently running on GPU.",
"Number of requests waiting to be processed.") labelnames=labelnames)
self.gauge_scheduler_swapped = Gauge(
gauge_gpu_cache_usage = Gauge( name="vllm:num_requests_swapped",
"vllm:gpu_cache_usage_perc", documentation="Number of requests swapped to CPU.",
"GPU KV-cache usage. 1 means 100 percent usage.") labelnames=labelnames)
gauge_cpu_cache_usage = Gauge( self.gauge_scheduler_waiting = Gauge(
"vllm:cpu_cache_usage_perc", name="vllm:num_requests_waiting",
"CPU KV-cache usage. 1 means 100 percent usage.") documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
histogram_time_to_first_token = Histogram( self.gauge_gpu_cache_usage = Gauge(
"vllm:time_to_first_token_seconds", name="vllm:gpu_cache_usage_perc",
"Histogram of time to first token in seconds.", documentation="GPU KV-cache usage. 1 means 100 percent usage.",
buckets=[ labelnames=labelnames)
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, self.gauge_cpu_cache_usage = Gauge(
2.5, 5.0, 7.5, 10.0 name="vllm:cpu_cache_usage_perc",
]) documentation="CPU KV-cache usage. 1 means 100 percent usage.",
histogram_time_per_output_tokens = Histogram( labelnames=labelnames)
"vllm:time_per_output_token_seconds",
"Histogram of time per output token in seconds.", # Raw stats from last model iteration
buckets=[ self.counter_prompt_tokens = Counter(
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5 name="vllm:prompt_tokens_total",
]) documentation="Number of prefill tokens processed.",
histogram_e2e_request_latency = Histogram( labelnames=labelnames)
"vllm:e2e_request_latency_seconds", self.counter_generation_tokens = Counter(
"Histogram of end to end request latency in seconds.", name="vllm:generation_tokens_total",
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = Histogram(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
self.histogram_e2e_request_latency = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Legacy metrics
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
# end-metrics-definitions # end-metrics-definitions
...@@ -87,7 +119,7 @@ class Stats: ...@@ -87,7 +119,7 @@ class Stats:
class StatLogger: class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout.""" """StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float) -> None: def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally. # Metadata for logging locally.
self.last_local_log = time.monotonic() self.last_local_log = time.monotonic()
self.local_interval = local_interval self.local_interval = local_interval
...@@ -96,6 +128,14 @@ class StatLogger: ...@@ -96,6 +128,14 @@ class StatLogger:
self.num_prompt_tokens: List[int] = [] self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = [] self.num_generation_tokens: List[int] = []
# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float: def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log)) return float(np.sum(tracked_stats) / (now - self.last_local_log))
...@@ -105,23 +145,33 @@ class StatLogger: ...@@ -105,23 +145,33 @@ class StatLogger:
def _log_prometheus(self, stats: Stats) -> None: def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges. # Set system stat gauges.
gauge_scheduler_running.set(labels, stats.num_running) self.metrics.gauge_scheduler_running.labels(**self.labels).set(
gauge_scheduler_swapped.set(labels, stats.num_swapped) stats.num_running)
gauge_scheduler_waiting.set(labels, stats.num_waiting) self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
gauge_gpu_cache_usage.set(labels, stats.gpu_cache_usage) stats.num_swapped)
gauge_cpu_cache_usage.set(labels, stats.cpu_cache_usage) self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
stats.num_waiting)
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
stats.gpu_cache_usage)
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
stats.cpu_cache_usage)
# Add to token counters. # Add to token counters.
counter_prompt_tokens.add(labels, stats.num_prompt_tokens) self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
counter_generation_tokens.add(labels, stats.num_generation_tokens) stats.num_prompt_tokens)
self.metrics.counter_generation_tokens.labels(**self.labels).inc(
stats.num_generation_tokens)
# Observe request level latencies in histograms. # Observe request level latencies in histograms.
for ttft in stats.time_to_first_tokens: for ttft in stats.time_to_first_tokens:
histogram_time_to_first_token.observe(labels, ttft) self.metrics.histogram_time_to_first_token.labels(
**self.labels).observe(ttft)
for tpot in stats.time_per_output_tokens: for tpot in stats.time_per_output_tokens:
histogram_time_per_output_tokens.observe(labels, tpot) self.metrics.histogram_time_per_output_token.labels(
**self.labels).observe(tpot)
for e2e in stats.time_e2e_requests: for e2e in stats.time_e2e_requests:
histogram_e2e_request_latency.observe(labels, e2e) self.metrics.histogram_e2e_request_latency.labels(
**self.labels).observe(e2e)
def _log_prometheus_interval(self, prompt_throughput: float, def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None: generation_throughput: float) -> None:
...@@ -130,8 +180,10 @@ class StatLogger: ...@@ -130,8 +180,10 @@ class StatLogger:
# Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens # Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens
# Which log raw data and calculate summaries using rate() on the grafana/prometheus side. # Which log raw data and calculate summaries using rate() on the grafana/prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666 # See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
gauge_avg_prompt_throughput.set(labels, prompt_throughput) self.metrics.gauge_avg_prompt_throughput.labels(
gauge_avg_generation_throughput.set(labels, generation_throughput) **self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None: def log(self, stats: Stats) -> None:
"""Called by LLMEngine. """Called by LLMEngine.
......
@@ -6,8 +6,7 @@ import os
 import importlib
 import inspect

-from aioprometheus import MetricsMiddleware
-from aioprometheus.asgi.starlette import metrics
+from prometheus_client import make_asgi_app
 import fastapi
 import uvicorn
 from http import HTTPStatus
@@ -18,7 +17,6 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.metrics import add_global_metrics_labels
 from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
 from vllm.logger import init_logger
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -141,8 +139,9 @@ def parse_args():
     return parser.parse_args()

-app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
-app.add_route("/metrics", metrics)  # Exposes HTTP metrics
+# Add prometheus asgi middleware to route /metrics requests
+metrics_app = make_asgi_app()
+app.mount("/metrics", metrics_app)

 @app.exception_handler(RequestValidationError)
@@ -242,9 +241,6 @@ if __name__ == "__main__":
     openai_serving_completion = OpenAIServingCompletion(
         engine, served_model, args.lora_modules)

-    # Register labels for metrics
-    add_global_metrics_labels(model_name=engine_args.model)
-
     app.root_path = args.root_path
     uvicorn.run(app,
                 host=args.host,
...
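To sanity-check the /metrics mount above, a small sketch assuming FastAPI, httpx, and prometheus_client are installed; the demo app below is a stand-in, not the vLLM API server itself.

from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus_client import make_asgi_app

demo_app = FastAPI()
# Serve the Prometheus exposition format under /metrics, as the API server now does.
demo_app.mount("/metrics", make_asgi_app())

client = TestClient(demo_app)
response = client.get("/metrics")
assert response.status_code == 200
assert "python_info" in response.text  # default platform metrics from the global registry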
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
import time import time
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
import torch
class ErrorResponse(BaseModel): class ErrorResponse(BaseModel):
object: str = "error" object: str = "error"
...@@ -55,7 +57,7 @@ class UsageInfo(BaseModel): ...@@ -55,7 +57,7 @@ class UsageInfo(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: Union[str, List[Dict[str, str]]] messages: List[Dict[str, str]]
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0 top_p: Optional[float] = 1.0
n: Optional[int] = 1 n: Optional[int] = 1
...@@ -63,6 +65,8 @@ class ChatCompletionRequest(BaseModel): ...@@ -63,6 +65,8 @@ class ChatCompletionRequest(BaseModel):
seed: Optional[int] = None seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False stream: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0 presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None logit_bias: Optional[Dict[str, float]] = None
...@@ -72,6 +76,7 @@ class ChatCompletionRequest(BaseModel): ...@@ -72,6 +76,7 @@ class ChatCompletionRequest(BaseModel):
top_k: Optional[int] = -1 top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
...@@ -81,8 +86,28 @@ class ChatCompletionRequest(BaseModel): ...@@ -81,8 +86,28 @@ class ChatCompletionRequest(BaseModel):
min_p: Optional[float] = 0.0 min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0 length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self) -> SamplingParams: def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams( return SamplingParams(
n=self.n, n=self.n,
presence_penalty=self.presence_penalty, presence_penalty=self.presence_penalty,
...@@ -95,16 +120,34 @@ class ChatCompletionRequest(BaseModel): ...@@ -95,16 +120,34 @@ class ChatCompletionRequest(BaseModel):
stop=self.stop, stop=self.stop,
stop_token_ids=self.stop_token_ids, stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens, max_tokens=self.max_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of, best_of=self.best_of,
top_k=self.top_k, top_k=self.top_k,
ignore_eos=self.ignore_eos, ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors,
) )
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
model: str model: str
...@@ -129,6 +172,7 @@ class CompletionRequest(BaseModel): ...@@ -129,6 +172,7 @@ class CompletionRequest(BaseModel):
top_k: Optional[int] = -1 top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False use_beam_search: Optional[bool] = False
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list) stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True
...@@ -136,10 +180,27 @@ class CompletionRequest(BaseModel): ...@@ -136,10 +180,27 @@ class CompletionRequest(BaseModel):
min_p: Optional[float] = 0.0 min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0 length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self): def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams( return SamplingParams(
n=self.n, n=self.n,
best_of=self.best_of, best_of=self.best_of,
...@@ -157,13 +218,29 @@ class CompletionRequest(BaseModel): ...@@ -157,13 +218,29 @@ class CompletionRequest(BaseModel):
max_tokens=self.max_tokens if not echo_without_generation else 1, max_tokens=self.max_tokens if not echo_without_generation else 1,
logprobs=self.logprobs, logprobs=self.logprobs,
use_beam_search=self.use_beam_search, use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None, prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens), spaces_between_special_tokens=(self.spaces_between_special_tokens),
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors,
) )
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
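To illustrate what the logit_bias_logits_processor closures defined above do at sampling time, here is a small standalone sketch with made-up token ids and bias values; only torch is required.

from typing import Dict, List

import torch

logit_bias: Dict[str, float] = {"5": 250.0, "7": -300.0}  # illustrative values


def logit_bias_logits_processor(token_ids: List[int],
                                logits: torch.Tensor) -> torch.Tensor:
    for token_id, bias in logit_bias.items():
        # Clamp the bias between -100 and 100 per the OpenAI API spec.
        bias = min(100, max(-100, bias))
        logits[int(token_id)] += bias
    return logits


logits = logit_bias_logits_processor(token_ids=[], logits=torch.zeros(10))
assert logits[5].item() == 100.0 and logits[7].item() == -100.0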
class LogProbs(BaseModel): class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list) text_offset: List[int] = Field(default_factory=list)
...@@ -212,6 +289,7 @@ class ChatMessage(BaseModel): ...@@ -212,6 +289,7 @@ class ChatMessage(BaseModel):
class ChatCompletionResponseChoice(BaseModel): class ChatCompletionResponseChoice(BaseModel):
index: int index: int
message: ChatMessage message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[Literal["stop", "length"]] = None
...@@ -232,6 +310,7 @@ class DeltaMessage(BaseModel): ...@@ -232,6 +310,7 @@ class DeltaMessage(BaseModel):
class ChatCompletionResponseStreamChoice(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel):
index: int index: int
delta: DeltaMessage delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None finish_reason: Optional[Literal["stop", "length"]] = None
......
@@ -12,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (
     UsageInfo)
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 logger = init_logger(__name__)
@@ -39,19 +40,13 @@ class OpenAIServingChat(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/chat/create
         for the API specification. This API mimics the OpenAI ChatCompletion API.
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - function_call (Users should implement this by themselves)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            # TODO: support logit_bias in vLLM engine.
-            return self.create_error_response(
-                "logit_bias is not currently supported")
         try:
             prompt = self.tokenizer.apply_chat_template(
                 conversation=request.messages,
@@ -68,6 +63,14 @@ class OpenAIServingChat(OpenAIServing):
                 prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            guided_decode_logits_processor = (
+                await get_guided_decoding_logits_processor(
+                    request, self.engine.get_tokenizer()))
+            if guided_decode_logits_processor:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logits_processor)
         except ValueError as e:
             return self.create_error_response(str(e))
@@ -86,7 +89,7 @@ class OpenAIServingChat(OpenAIServing):
         if request.add_generation_prompt:
             return self.response_role
         else:
-            return request.messages[-1].role
+            return request.messages[-1]["role"]
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
@@ -101,7 +104,10 @@ class OpenAIServingChat(OpenAIServing):
         role = self.get_chat_request_role(request)
         for i in range(request.n):
             choice_data = ChatCompletionResponseStreamChoice(
-                index=i, delta=DeltaMessage(role=role), finish_reason=None)
+                index=i,
+                delta=DeltaMessage(role=role),
+                logprobs=None,
+                finish_reason=None)
             chunk = ChatCompletionStreamResponse(id=request_id,
                                                  object=chunk_object_type,
                                                  created=created_time,
@@ -118,6 +124,7 @@ class OpenAIServingChat(OpenAIServing):
                         "content") and request.messages[-1].get(
                             "role") == role:
                 last_msg_content = request.messages[-1]["content"]
+
             if last_msg_content:
                 for i in range(request.n):
                     choice_data = ChatCompletionResponseStreamChoice(
@@ -129,6 +136,7 @@ class OpenAIServingChat(OpenAIServing):
                         object=chunk_object_type,
                         created=created_time,
                         choices=[choice_data],
+                        logprobs=None,
                        model=model_name)
                     data = chunk.model_dump_json(exclude_unset=True)
                     yield f"data: {data}\n\n"
@@ -145,15 +153,29 @@ class OpenAIServingChat(OpenAIServing):
                 if finish_reason_sent[i]:
                     continue
+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None
+                if request.logprobs:
+                    logprobs = self._create_logprobs(
+                        token_ids=delta_token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                        initial_text_offset=len(previous_texts[i]),
+                    )
+                else:
+                    logprobs = None
                 delta_text = output.text[len(previous_texts[i]):]
                 previous_texts[i] = output.text
                 previous_num_tokens[i] = len(output.token_ids)
                 if output.finish_reason is None:
                     # Send token-by-token response for each request.n
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=None)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -174,6 +196,7 @@ class OpenAIServingChat(OpenAIServing):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(content=delta_text),
+                        logprobs=logprobs,
                         finish_reason=output.finish_reason)
                     chunk = ChatCompletionStreamResponse(
                         id=request_id,
@@ -208,11 +231,25 @@ class OpenAIServingChat(OpenAIServing):
         assert final_res is not None
         choices = []
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
+            token_ids = output.token_ids
+            top_logprobs = output.logprobs
+            if request.logprobs:
+                logprobs = self._create_logprobs(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
             choice_data = ChatCompletionResponseChoice(
                 index=output.index,
                 message=ChatMessage(role=role, content=output.text),
+                logprobs=logprobs,
                 finish_reason=output.finish_reason,
             )
             choices.append(choice_data)
...
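Taken together, these changes make the chat endpoint return per-token log probabilities and accept guided-decoding constraints. A hedged client-side sketch of how that might be exercised, assuming a server running locally on port 8000 (the URL, model name and prompt are placeholders; the request schema follows the fields added in this diff, where logprobs is an integer count of top logprobs per token):

import requests  # any HTTP client works; used here for brevity

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Answer with one word: red, green or blue."}],
    "logprobs": 2,  # per this diff, an int: number of top logprobs per generated token
    "guided_choice": ["red", "green", "blue"],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()
print(resp["choices"][0]["message"]["content"])
print(resp["choices"][0]["logprobs"])  # populated via the new logprobs field on the response choice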
@@ -5,7 +5,7 @@ from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from .protocol import (
+from vllm.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -16,6 +16,7 @@ from .protocol import (
 )
 from vllm.outputs import RequestOutput
 from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
+from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 logger = init_logger(__name__)
@@ -95,7 +96,7 @@ async def completion_stream_generator(
                         logprobs=logprobs,
                         finish_reason=finish_reason,
                     )
-                ]).model_dump_json(exclude_unset=True)
+                ]).model_dump_json()
             yield f"data: {response_json}\n\n"
             if output.finish_reason is not None:  # return final usage
@@ -120,7 +121,7 @@ async def completion_stream_generator(
                        )
                    ],
                    usage=final_usage,
-                ).model_dump_json(exclude_unset=True)
+                ).model_dump_json()
                yield f"data: {response_json}\n\n"
     yield "data: [DONE]\n\n"
@@ -264,10 +265,9 @@ class OpenAIServingCompletion(OpenAIServing):
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
-        NOTE: Currently we do not support the following features:
+        NOTE: Currently we do not support the following feature:
             - suffix (the language models we currently support do not support
               suffix)
-            - logit_bias (to be supported by vLLM engine)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -277,9 +277,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if request.suffix is not None:
             return self.create_error_response(
                 "suffix is not currently supported")
-        if request.logit_bias is not None and len(request.logit_bias) > 0:
-            return self.create_error_response(
-                "logit_bias is not currently supported")
         model_name = request.model
         request_id = f"cmpl-{random_uuid()}"
@@ -290,6 +287,14 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            guided_decode_logit_processor = (
+                await get_guided_decoding_logits_processor(
+                    request, self.engine.get_tokenizer()))
+            if guided_decode_logit_processor is not None:
+                if sampling_params.logits_processors is None:
+                    sampling_params.logits_processors = []
+                sampling_params.logits_processors.append(
+                    guided_decode_logit_processor)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
             for i, prompt in enumerate(prompts):
@@ -301,7 +306,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     request, prompt=prompt)
                 generators.append(
-                    self.engine.generate(None,
+                    self.engine.generate(prompt,
                                          sampling_params,
                                          f"{request_id}-{i}",
                                          prompt_token_ids=input_ids,
...
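The completions handler gets the same treatment: the logit_bias rejection is dropped, the guided-decoding logits processor is appended to sampling_params, and the original prompt text (rather than None) is now passed to engine.generate alongside the token ids. An illustrative request against the completions endpoint, with the same placeholder caveats as the sketch above (server URL and model name are assumptions):

import requests

payload = {
    "model": "my-model",  # placeholder model name
    "prompt": "The ISO date for Christmas 2024 is ",
    "max_tokens": 16,
    "guided_regex": r"\d{4}-\d{2}-\d{2}",  # only one guided-decoding field may be set per request
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload).json()
print(resp["choices"][0]["text"])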
@@ -795,6 +795,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         self.dtype = dtype
         self.device = device
+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size
...
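The new property forwards logits_as_hidden_states from the wrapped sampler, mirroring the existing vocab_size property, presumably so code that reads the flag off the sampler keeps working when the sampler is wrapped for LoRA. A stripped-down, generic illustration of this delegation pattern (class and attribute values here are invented for the example, not taken from the diff):

class BaseSampler:
    # stand-in for the wrapped sampler; the real class lives in vLLM
    logits_as_hidden_states = False
    vocab_size = 32000

class SamplerWrapper:
    """Generic sketch of the read-only delegation used by SamplerWithLoRA."""

    def __init__(self, base_layer: BaseSampler) -> None:
        self.base_layer = base_layer

    @property
    def logits_as_hidden_states(self) -> bool:
        # forward the flag unchanged so wrapping stays transparent to callers
        return self.base_layer.logits_as_hidden_states

    @property
    def vocab_size(self) -> int:
        return self.base_layer.vocab_size

wrapper = SamplerWrapper(BaseSampler())
assert wrapper.logits_as_hidden_states is False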