"vscode:/vscode.git/clone" did not exist on "d57b5bd27011b87920ec1f64d8959054f64c3423"
Commit e00b0a19 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.3.3

parents ead94d93 3f1166ab
"""Tests for the SamplingParams class.
"""
from vllm import SamplingParams
def test_max_tokens_none():
"""max_tokens=None should be allowed"""
SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
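# Illustrative usage (not part of this commit): max_tokens=None is intended to
# leave the generation length uncapped, so decoding stops at an EOS token or the
# model's context-length limit. The sketch below assumes the same small
# 'JackFram/llama-68m' checkpoint used by the worker tests later in this diff.
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="JackFram/llama-68m")
#     params = SamplingParams(temperature=0.0, max_tokens=None)
#     outputs = llm.generate(["Hello, my name is"], params)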
import torch
import random
import pytest
from unittest.mock import MagicMock

from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
from vllm.worker.worker import Worker
from vllm.model_executor.utils import set_random_seed

from .utils import (create_execute_model_data, create_worker,
                    create_seq_group_metadata_from_prompts, zero_kv_cache,
                    patch_execute_model_with_seeds,
                    assert_logprobs_dict_allclose)


@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
    """Test that the multi step worker checks for sufficient space in the KV
    cache. It should throw if it cannot run all the steps.
    """
    block_size = 16
    num_gpu_blocks = 2048 // block_size

    prompts = [
        list(range(block_size * 3)),
        list(range(block_size * 2)),
    ]
    prev_output_tokens = [
        list(range(block_size * 1)),
        list(range(block_size * 2)),
    ]
    final_seq_lens = [
        len(prompt + output) + num_steps
        for prompt, output in zip(prompts, prev_output_tokens)
    ]

    inputs = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_seq_lens,
        continuations=prev_output_tokens)

    assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
    worker = MagicMock()
    worker.model_runner.block_size = block_size

    for seq_group_metadata in inputs:
        original_block_tables = seq_group_metadata.block_tables

        # No exception.
        assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = {
            seq_id: []
            for seq_id, physical_blocks in original_block_tables.items()
        }

        # Expect exception.
        with pytest.raises(ValueError,
                           match='times but found insufficient KV space for'):
            assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = original_block_tables


@torch.inference_mode()
def test_same_output_for_single_step():
    """Verify the multi step worker produces the same output as the normal
    worker for num_steps=1.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    multi_step_worker.model_runner = worker.model_runner
    multi_step_worker.cache_engine = worker.cache_engine

    num_steps = 1

    prompts = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]
    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]

    multi_step_execute_model_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))
    single_step_execute_model_data = create_execute_model_data(
        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
            prompts, num_gpu_blocks, block_size,
            final_seq_lens=final_seq_lens))

    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    actual_output = multi_step_worker.execute_model_multi_step(
        **multi_step_execute_model_data.to_dict(), num_steps=num_steps)
    assert len(actual_output) == num_steps
    actual_output = actual_output[0]

    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        **single_step_execute_model_data.to_dict(), )

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
    ]
    actual_logprobs = [output.samples[0].logprobs for output in actual_output]

    expected_token_ids = [
        output.samples[0].output_token for output in expected_output
    ]
    expected_logprobs = [
        output.samples[0].logprobs for output in expected_output
    ]

    assert actual_token_ids == expected_token_ids

    print(f'{actual_logprobs=}')
    print(f'{expected_logprobs=}')
    assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)


@torch.inference_mode()
def test_same_output_for_multi_step():
    """Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once, and
    then runs the worker num_steps times, and compares the output.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    # Make sure we go over the block boundary.
    num_steps = block_size + 1

    random.seed(seed)
    prompts = [[
        random.randint(0, 1000) for _ in range(random.randint(10, 20))
    ] for _ in range(10)]
    final_seq_lens = [len(prompt) + num_steps for prompt in prompts]

    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    continuations = [[1] for _ in prompts]
    execute_model_data = create_execute_model_data(
        create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_seq_lens=final_seq_lens), )

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    multi_step_output = multi_step_worker.execute_model_multi_step(
        **execute_model_data.to_dict(), num_steps=num_steps)

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    single_step_output = []
    continuations = [[1] for _ in prompts]
    set_random_seed(seed)

    for _ in multi_step_output:
        execute_model_data = create_execute_model_data(
            create_seq_group_metadata_from_prompts(
                prompts,
                num_gpu_blocks,
                block_size,
                continuations=continuations,
                final_seq_lens=final_seq_lens))

        single_step_output.append(
            worker.execute_model(**execute_model_data.to_dict(), ))

        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Get token ids and logprobs for comparison.
    multi_step_output_logprobs = [[] for _ in prompts]
    single_step_output_logprobs = [[] for _ in prompts]

    multi_step_output_token_ids = [[] for _ in prompts]
    single_step_output_token_ids = [[] for _ in prompts]
    for i, _ in enumerate(prompts):
        for multi_step, single_step in zip(multi_step_output,
                                           single_step_output):
            multi_step_output_token_ids[i].append(
                multi_step[i].samples[0].output_token)
            single_step_output_token_ids[i].append(
                single_step[i].samples[0].output_token)

            multi_step_output_logprobs[i].append(
                multi_step[i].samples[0].logprobs)
            single_step_output_logprobs[i].append(
                single_step[i].samples[0].logprobs)

    # Print per-sequence token ids
    for i, (multi_step_tokens, single_step_tokens) in enumerate(
            zip(multi_step_output_token_ids, single_step_output_token_ids)):
        print(f'{i=} {multi_step_tokens=}')
        print(f'{i=} {single_step_tokens=}')
        print(f'{i=} equal {multi_step_tokens == single_step_tokens}')

    # Assert token ids are equal.
    for multi_step_tokens, single_step_tokens in zip(
            multi_step_output_token_ids, single_step_output_token_ids):
        assert multi_step_tokens == single_step_tokens

    # Assert logprobs are equal.
    for multi_step_logprobs, single_step_logprobs in zip(
            multi_step_output_logprobs, single_step_output_logprobs):
        assert_logprobs_dict_allclose(multi_step_logprobs,
                                      single_step_logprobs)
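# Note (illustrative, not part of the original file): these tests assume a
# CUDA-capable GPU and that the small 'JackFram/llama-68m' checkpoint can be
# downloaded from the Hugging Face Hub. Mirroring the SamplingParams test above,
# they could be run directly via:
#
#     if __name__ == "__main__":
#         import pytest
#         pytest.main([__file__])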
import torch
from typing import List, Optional, Dict

from vllm.worker.worker import Worker
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.engine.arg_utils import EngineArgs
from vllm.sequence import SequenceGroupMetadata, SequenceData
from vllm.sampling_params import SamplingParams
from vllm.worker.cache_engine import CacheEngine
from vllm.model_executor.utils import set_random_seed
from dataclasses import dataclass, fields


@dataclass
class ExecuteModelData:
    """Helper data structure which facilitates cleaner tests.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata]
    blocks_to_swap_in: Dict[int, int]
    blocks_to_swap_out: Dict[int, int]
    blocks_to_copy: Dict[int, List[int]]

    def to_dict(self):
        return dict(
            (field.name, getattr(self, field.name)) for field in fields(self))


def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    return (seq_len + block_size - 1) // block_size
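# Worked example (illustrative, not in the original file): despite its name,
# this helper returns the number of blocks needed to hold `seq_len` tokens,
# i.e. ceil(seq_len / block_size):
#
#     round_up_to_next_block(17, 16) == 2  # 17 tokens spill into a second block
#     round_up_to_next_block(32, 16) == 2  # exact multiples need no extra block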
def create_execute_model_data(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]] = None,
    blocks_to_swap_out: Optional[Dict[int, int]] = None,
    blocks_to_copy: Optional[Dict[int, int]] = None,
) -> ExecuteModelData:
    if blocks_to_swap_in is None:
        blocks_to_swap_in = {}
    if blocks_to_swap_out is None:
        blocks_to_swap_out = {}
    if blocks_to_copy is None:
        blocks_to_copy = {}

    return ExecuteModelData(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )


def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
    seed_iter = iter(rand_seeds)
    original_execute_model = worker.execute_model

    def new_execute_model(*args, **kwargs):
        result = original_execute_model(*args, **kwargs)
        set_random_seed(next(seed_iter))
        return result

    return new_execute_model


def zero_kv_cache(cache_engine: CacheEngine):
    assert cache_engine.gpu_cache
    for key_blocks, value_blocks in cache_engine.gpu_cache:
        key_blocks.zero_()
        value_blocks.zero_()


def create_worker(cls: type,
                  model_name: str,
                  block_size: int,
                  num_gpu_blocks: int,
                  seed: int,
                  is_driver_worker: bool = True,
                  enforce_eager: bool = True):
    engine_args = EngineArgs(
        model=model_name,
        seed=seed,
        block_size=block_size,
        enforce_eager=enforce_eager,
    )

    (model_config, cache_config, parallel_config, scheduler_config,
     device_config, _) = engine_args.create_engine_configs()

    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())

    worker = cls(
        model_config=model_config,
        parallel_config=parallel_config,
        scheduler_config=scheduler_config,
        device_config=device_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=is_driver_worker,
    )

    worker.init_model()
    worker.load_model()

    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = 0
    worker.init_cache_engine(cache_config)
    worker.warm_up_model()

    return worker


def create_seq_group_metadata_from_prompts(
    prompts: List[List[int]],
    num_gpu_blocks: int,
    block_size: int,
    final_seq_lens: List[int],
    continuations: Optional[List[List[int]]] = None,
    num_tokens_processed: Optional[List[int]] = None,
    seq_ids: Optional[List[int]] = None,
) -> List[SequenceGroupMetadata]:
    if continuations is None:
        continuations = [[] for _ in prompts]

    if num_tokens_processed is None:
        # Default to 1 token missing from kv cache for generation sequences.
        num_tokens_processed = []
        for continuation, prompt in zip(continuations, prompts):
            # If prefill, then default to zero tokens processed.
            if not continuation:
                num_tokens_processed.append(0)
            else:
                # If generation, then default to all but one tokens processed.
                num_tokens_processed.append(
                    len(continuation) + len(prompt) - 1)

    if seq_ids is None:
        seq_ids = list(i for i, _ in enumerate(prompts))

    free_gpu_blocks = list(range(num_gpu_blocks))
    block_allocations = {
        i: [
            free_gpu_blocks.pop()
            for _ in range(round_up_to_next_block(final_len, block_size))
        ]
        for i, final_len in enumerate(final_seq_lens)
    }

    return [
        SequenceGroupMetadata(
            request_id=str(i),
            is_prompt=len(cont_token_ids) == 0,
            seq_data={
                i:
                SequenceData(prompt_token_ids=prompt_token_ids[:] +
                             cont_token_ids[:])
            },
            sampling_params=SamplingParams(temperature=0.0, ),
            block_tables={i: block_allocations[i][:]},
        ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
        enumerate(zip(prompts, continuations, num_tokens_processed))
    ]


def assert_logprobs_dict_allclose(
        actual_logprobs: List[Dict[int, float]],
        expected_logprobs: List[Dict[int, float]]) -> None:
    for single_step_actual_logprobs, single_step_expected_logprobs in zip(
            actual_logprobs, expected_logprobs):
        assert set(single_step_actual_logprobs.keys()) == set(
            single_step_expected_logprobs.keys())
        for token_id in single_step_actual_logprobs:
            actual = torch.tensor(single_step_actual_logprobs[token_id])
            expected = torch.tensor(single_step_expected_logprobs[token_id])
            assert torch.allclose(actual, expected)
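# Illustrative example (not part of this commit): two single-step logprob dicts
# that satisfy the check above, since the token ids match and the values fall
# within torch.allclose's default tolerance.
#
#     assert_logprobs_dict_allclose(
#         actual_logprobs=[{42: -0.3010, 7: -1.2040}],
#         expected_logprobs=[{42: -0.3010000001, 7: -1.2039999999}],
#     )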
@@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner


 def test_prepare_prompt():
-    model_runner = ModelRunner(None, None, None)
+    model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)

     batch_size = random.randint(1, 256)
@@ -33,11 +33,12 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _, return_prompt_lens = (
+    input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
         model_runner._prepare_prompt(seq_group_metadata_list))
     assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
-                                                     prompt_lens)
+                                                     prompt_lens,
+                                                     subquery_lens=prompt_lens)
     assert input_tokens.shape == (batch_size, max_seq_len)
     assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
......
@@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.version import __dcu_version__

-__version__ = "0.2.7"
+__version__ = "0.3.3"

 __all__ = [
     "LLM",
......
@@ -66,3 +66,7 @@ class PhysicalTokenBlock:
         return (f'PhysicalTokenBlock(device={self.device}, '
                 f'block_number={self.block_number}, '
                 f'ref_count={self.ref_count})')
+
+
+# Mapping: logical block number -> physical block.
+BlockTable = List[PhysicalTokenBlock]
from typing import Optional, Union from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import os import os
from packaging.version import Version
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -42,6 +44,9 @@ class ModelConfig: ...@@ -42,6 +44,9 @@ class ModelConfig:
revision: The specific model version to use. It can be a branch name, revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default a tag name, or a commit id. If unspecified, will use the default
version. version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use branch name, a tag name, or a commit id. If unspecified, will use
the default version. the default version.
...@@ -68,6 +73,7 @@ class ModelConfig: ...@@ -68,6 +73,7 @@ class ModelConfig:
dtype: Union[str, torch.dtype], dtype: Union[str, torch.dtype],
seed: int, seed: int,
revision: Optional[str] = None, revision: Optional[str] = None,
code_revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
quantization: Optional[str] = None, quantization: Optional[str] = None,
...@@ -82,6 +88,7 @@ class ModelConfig: ...@@ -82,6 +88,7 @@ class ModelConfig:
self.load_format = load_format self.load_format = load_format
self.seed = seed self.seed = seed
self.revision = revision self.revision = revision
self.code_revision = code_revision
self.tokenizer_revision = tokenizer_revision self.tokenizer_revision = tokenizer_revision
self.quantization = quantization self.quantization = quantization
self.enforce_eager = enforce_eager self.enforce_eager = enforce_eager
...@@ -91,14 +98,18 @@ class ModelConfig: ...@@ -91,14 +98,18 @@ class ModelConfig:
# download model from ModelScope hub, # download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use. # lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
model_path = snapshot_download(model_id=model, if not os.path.exists(model):
cache_dir=download_dir, model_path = snapshot_download(model_id=model,
revision=revision) cache_dir=download_dir,
revision=revision)
else:
model_path = model
self.model = model_path self.model = model_path
self.download_dir = model_path self.download_dir = model_path
self.tokenizer = model_path self.tokenizer = model_path
self.hf_config = get_config(self.model, trust_remote_code, revision) self.hf_config = get_config(self.model, trust_remote_code, revision,
code_revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config, self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len) max_model_len)
...@@ -144,15 +155,21 @@ class ModelConfig: ...@@ -144,15 +155,21 @@ class ModelConfig:
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm"] supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
rocm_not_supported_quantization = ["awq"] rocm_not_supported_quantization = ["awq", "marlin"]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
# Parse quantization method from the HF model config, if available. # Parse quantization method from the HF model config, if available.
hf_quant_config = getattr(self.hf_config, "quantization_config", None) hf_quant_config = getattr(self.hf_config, "quantization_config", None)
if hf_quant_config is not None: if hf_quant_config is not None:
hf_quant_method = str(hf_quant_config["quant_method"]).lower() hf_quant_method = str(hf_quant_config["quant_method"]).lower()
# If the GPTQ model is serialized in marlin format, use marlin.
if (hf_quant_method == "gptq"
and "is_marlin_format" in hf_quant_config
and hf_quant_config["is_marlin_format"]):
hf_quant_method = "marlin"
if self.quantization is None: if self.quantization is None:
self.quantization = hf_quant_method self.quantization = hf_quant_method
elif self.quantization != hf_quant_method: elif self.quantization != hf_quant_method:
...@@ -172,9 +189,11 @@ class ModelConfig: ...@@ -172,9 +189,11 @@ class ModelConfig:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not supported " f"{self.quantization} quantization is currently not supported "
f"in ROCm.") f"in ROCm.")
logger.warning(f"{self.quantization} quantization is not fully " if self.quantization != "marlin":
"optimized yet. The speed can be slower than " logger.warning(
"non-quantized models.") f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
def _verify_cuda_graph(self) -> None: def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None: if self.max_context_len_to_capture is None:
...@@ -212,6 +231,8 @@ class ModelConfig: ...@@ -212,6 +231,8 @@ class ModelConfig:
return self.hf_config.hidden_size return self.hf_config.hidden_size
def get_head_size(self) -> int: def get_head_size(self) -> int:
if hasattr(self.hf_config, "head_dim"):
return self.hf_config.head_dim
# FIXME(woosuk): This may not be true for all models. # FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads return self.hf_config.hidden_size // self.hf_config.num_attention_heads
...@@ -272,6 +293,7 @@ class CacheConfig: ...@@ -272,6 +293,7 @@ class CacheConfig:
gpu_memory_utilization: Fraction of GPU memory to use for the gpu_memory_utilization: Fraction of GPU memory to use for the
vLLM execution. vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB). swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
""" """
def __init__( def __init__(
...@@ -279,24 +301,53 @@ class CacheConfig: ...@@ -279,24 +301,53 @@ class CacheConfig:
block_size: int, block_size: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
swap_space: int, swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window self.sliding_window = sliding_window
self._verify_args() self._verify_args()
self._verify_cache_dtype()
# Will be set after profiling. # Will be set after profiling.
self.num_gpu_blocks = None self.num_gpu_blocks = None
self.num_cpu_blocks = None self.num_cpu_blocks = None
def metrics_info(self):
# convert cache_config to dict(key: str, value:str) for prometheus metrics info
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.gpu_memory_utilization > 1.0: if self.gpu_memory_utilization > 1.0:
raise ValueError( raise ValueError(
"GPU memory utilization must be less than 1.0. Got " "GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.") f"{self.gpu_memory_utilization}.")
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8_e5m2":
nvcc_cuda_version = get_nvcc_cuda_version()
if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
raise ValueError(
"FP8 is not supported when cuda version is lower than 11.8."
)
device_name = torch.cuda.get_device_name()
if "AMD" in device_name:
raise NotImplementedError(
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
logger.info(
"Using fp8_e5m2 data type to store kv cache. It reduces "
"the GPU memory footprint and boosts the performance. "
"But it may cause slight accuracy drop. "
"Currently we only support fp8 without scaling factors and "
"make e5m2 as a default format.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
...@@ -325,6 +376,11 @@ class ParallelConfig: ...@@ -325,6 +376,11 @@ class ParallelConfig:
worker_use_ray: Whether to use Ray for model workers. Will be set to worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1. greater than 1.
max_parallel_loading_workers: Maximum number of multiple batches
when load model sequentially. To avoid RAM OOM when using tensor
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
""" """
def __init__( def __init__(
...@@ -333,14 +389,24 @@ class ParallelConfig: ...@@ -333,14 +389,24 @@ class ParallelConfig:
tensor_parallel_size: int, tensor_parallel_size: int,
worker_use_ray: bool, worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None, max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
) -> None: ) -> None:
self.pipeline_parallel_size = pipeline_parallel_size self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size if is_neuron():
# For Neuron device support, here we assign TP=1 to avoid sharding within vLLM directly.
# Transformer-neuronx would take neuron_tp_degree attribute, and distribute the workload
# to multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.world_size = pipeline_parallel_size * tensor_parallel_size self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if self.world_size > 1: # Ray worker is not supported for Neuron backend.
if self.world_size > 1 and not is_neuron():
self.worker_use_ray = True self.worker_use_ray = True
self._verify_args() self._verify_args()
...@@ -348,6 +414,26 @@ class ParallelConfig: ...@@ -348,6 +414,26 @@ class ParallelConfig:
if self.pipeline_parallel_size > 1: if self.pipeline_parallel_size > 1:
raise NotImplementedError( raise NotImplementedError(
"Pipeline parallelism is not supported yet.") "Pipeline parallelism is not supported yet.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
# FIXME(woosuk): Fix the stability issues and re-enable the custom
# all-reduce kernel.
if not self.disable_custom_all_reduce and self.world_size > 1:
self.disable_custom_all_reduce = True
logger.info(
"Custom all-reduce kernels are temporarily disabled due to "
"stability issues. We will re-enable them once the issues are "
"resolved.")
class SchedulerConfig: class SchedulerConfig:
...@@ -397,6 +483,81 @@ class SchedulerConfig: ...@@ -397,6 +483,81 @@ class SchedulerConfig:
f"({self.max_num_seqs}).") f"({self.max_num_seqs}).")
class DeviceConfig:
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if torch.cuda.is_available():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron"
else:
raise RuntimeError("No supported device detected.")
else:
# Device type is assigned explicitly
self.device_type = device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
else:
# Set device with device type
self.device = torch.device(self.device_type)
@property
def is_neuron(self):
return self.device_type == "neuron"
@dataclass
class LoRAConfig:
max_lora_rank: int
max_loras: int
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
def __post_init__(self):
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
possible_max_ranks = (8, 16, 32, 64)
possible_lora_extra_vocab_size = (0, 256, 512)
if self.max_lora_rank not in possible_max_ranks:
raise ValueError(
f"max_lora_rank ({self.max_lora_rank}) must be one of "
f"{possible_max_ranks}.")
if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
raise ValueError(
f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
f"must be one of {possible_lora_extra_vocab_size}.")
if self.max_loras < 1:
raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
if self.max_cpu_loras is None:
self.max_cpu_loras = self.max_loras
elif self.max_cpu_loras < self.max_loras:
raise ValueError(
f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_loras ({self.max_loras})")
def verify_with_model_config(self, model_config: ModelConfig):
if self.lora_dtype in (None, "auto"):
self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None:
raise ValueError(
"LoRA is not supported with quantized models yet.")
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528:
raise ValueError(
"Due to limitations of the custom LoRA CUDA kernel, "
"max_num_batched_tokens must be <= 65528 when "
"LoRA is enabled.")
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16, "half": torch.float16,
"float16": torch.float16, "float16": torch.float16,
......
...@@ -2,13 +2,10 @@ ...@@ -2,13 +2,10 @@
import enum import enum
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from vllm.block import PhysicalTokenBlock from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device from vllm.utils import Device
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
class BlockAllocator: class BlockAllocator:
"""Manages free physical token blocks for a device. """Manages free physical token blocks for a device.
...@@ -105,6 +102,10 @@ class BlockSpaceManager: ...@@ -105,6 +102,10 @@ class BlockSpaceManager:
# the same prompt. This may not be true for preempted sequences. # the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks) num_required_blocks = len(seq.logical_token_blocks)
if seq_group.prefix is not None and seq_group.prefix.allocated:
num_required_blocks -= seq_group.prefix.get_num_blocks()
if self.block_sliding_window is not None: if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks, num_required_blocks = min(num_required_blocks,
self.block_sliding_window) self.block_sliding_window)
...@@ -125,8 +126,21 @@ class BlockSpaceManager: ...@@ -125,8 +126,21 @@ class BlockSpaceManager:
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = len(seq.logical_token_blocks)
block_table: BlockTable = [] block_table: BlockTable = []
for logical_idx in range(len(seq.logical_token_blocks)): prefix_block_table: BlockTable = []
num_prefix_blocks = 0
prefix = seq_group.prefix
if prefix is not None and prefix.allocated:
# Prefix has already been allocated. Use the existing block table.
num_prompt_blocks -= prefix.get_num_blocks()
for block in prefix.block_table:
block.ref_count += seq_group.num_seqs()
block_table.append(block)
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window): and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window] block = block_table[logical_idx % self.block_sliding_window]
...@@ -136,6 +150,15 @@ class BlockSpaceManager: ...@@ -136,6 +150,15 @@ class BlockSpaceManager:
block.ref_count = seq_group.num_seqs() block.ref_count = seq_group.num_seqs()
block_table.append(block) block_table.append(block)
if prefix is not None and not prefix.allocated:
# Allocate blocks for the prefix, we will compute the prefix's
# KV cache in this run.
num_prefix_blocks = prefix.get_num_blocks()
prefix_block_table = block_table[:num_prefix_blocks]
for block in prefix_block_table:
block.ref_count += 1
prefix.set_block_table(prefix_block_table)
# Assign the block table for each sequence. # Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy() self.block_tables[seq.seq_id] = block_table.copy()
...@@ -155,7 +178,7 @@ class BlockSpaceManager: ...@@ -155,7 +178,7 @@ class BlockSpaceManager:
if len(block_table) < len(logical_blocks): if len(block_table) < len(logical_blocks):
if (self.block_sliding_window if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window): and len(block_table) >= self.block_sliding_window):
# re-use a block # reuse a block
block_table.append(block_table[len(block_table) % block_table.append(block_table[len(block_table) %
self.block_sliding_window]) self.block_sliding_window])
else: else:
...@@ -210,10 +233,18 @@ class BlockSpaceManager: ...@@ -210,10 +233,18 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block. # CPU block -> GPU block.
if seq_group.prefix is not None:
# make sure to swap in the prefix first
assert seq_group.prefix.allocated and seq_group.prefix.computed
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = [] new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
if seq_group.prefix is not None:
for block in seq_group.prefix.block_table:
new_block_table.append(block)
block.ref_count += 1
for cpu_block in block_table: for cpu_block in block_table:
if cpu_block in mapping: if cpu_block in mapping:
...@@ -245,6 +276,12 @@ class BlockSpaceManager: ...@@ -245,6 +276,12 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
for gpu_block in block_table: for gpu_block in block_table:
if (seq_group.prefix is not None
and gpu_block in seq_group.prefix.block_table):
# NOTE: We do not swap out the prefix blocks for now.
self.gpu_allocator.free(gpu_block)
continue
if gpu_block in mapping: if gpu_block in mapping:
cpu_block = mapping[gpu_block] cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1 cpu_block.ref_count += 1
......
@@ -33,7 +33,7 @@ class FCFS(Policy):
         now: float,
         seq_group: SequenceGroup,
     ) -> float:
-        return now - seq_group.arrival_time
+        return now - seq_group.metrics.arrival_time


 class PolicyFactory:
......
from collections import deque from collections import deque
import enum import enum
import time import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set
from vllm.config import CacheConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory from vllm.core.policy import PolicyFactory
from vllm.lora.request import LoRARequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus) SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -48,11 +50,25 @@ class SchedulerOutputs: ...@@ -48,11 +50,25 @@ class SchedulerOutputs:
assert not (blocks_to_swap_in and blocks_to_swap_out) assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups self.ignored_seq_groups = ignored_seq_groups
self.num_loras = len(self.lora_requests)
if self.num_loras > 0:
self._sort_by_lora_ids()
def is_empty(self) -> bool: def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups. # NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy) and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool:
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.lora_request.lora_int_id
if g.lora_request else 0, g.request_id))
@property
def lora_requests(self) -> Set[LoRARequest]:
return {g.lora_request for g in self.scheduled_seq_groups}
class Scheduler: class Scheduler:
...@@ -60,9 +76,14 @@ class Scheduler: ...@@ -60,9 +76,14 @@ class Scheduler:
self, self,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
cache_config: CacheConfig, cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
) -> None: ) -> None:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.cache_config = cache_config self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
self.prompt_limit = min(self.scheduler_config.max_model_len, self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
...@@ -76,6 +97,9 @@ class Scheduler: ...@@ -76,6 +97,9 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window) sliding_window=self.cache_config.sliding_window)
# Create the prefix pool to cache the prefixes.
self.prefix_pool = PrefixPool(self.cache_config.block_size)
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: Deque[SequenceGroup] = deque() self.waiting: Deque[SequenceGroup] = deque()
# Sequence groups in the RUNNING state. # Sequence groups in the RUNNING state.
...@@ -83,6 +107,10 @@ class Scheduler: ...@@ -83,6 +107,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
self.swapped: Deque[SequenceGroup] = deque() self.swapped: Deque[SequenceGroup] = deque()
@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
def add_seq_group(self, seq_group: SequenceGroup) -> None: def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
self.waiting.append(seq_group) self.waiting.append(seq_group)
...@@ -104,7 +132,7 @@ class Scheduler: ...@@ -104,7 +132,7 @@ class Scheduler:
request_id = (request_id, ) request_id = (request_id, )
request_ids = set(request_id) request_ids = set(request_id)
for state_queue in [self.waiting, self.running, self.swapped]: for state_queue in [self.waiting, self.running, self.swapped]:
aborted_groups = [] aborted_groups: List[SequenceGroup] = []
for seq_group in state_queue: for seq_group in state_queue:
if not request_ids: if not request_ids:
# Using 'break' here may add two extra iterations, # Using 'break' here may add two extra iterations,
...@@ -117,7 +145,7 @@ class Scheduler: ...@@ -117,7 +145,7 @@ class Scheduler:
for aborted_group in aborted_groups: for aborted_group in aborted_groups:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(aborted_group) state_queue.remove(aborted_group)
for seq in seq_group.get_seqs(): for seq in aborted_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
continue continue
seq.status = SequenceStatus.FINISHED_ABORTED seq.status = SequenceStatus.FINISHED_ABORTED
...@@ -130,7 +158,7 @@ class Scheduler: ...@@ -130,7 +158,7 @@ class Scheduler:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(self) -> SchedulerOutputs: def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swaped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {} blocks_to_copy: Dict[int, List[int]] = {}
...@@ -146,14 +174,17 @@ class Scheduler: ...@@ -146,14 +174,17 @@ class Scheduler:
# requests in the generation phase. # requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs() num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running) for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = [] seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
# are added to the back. # are added to the back.
leftover_waiting_sequences = deque()
while self.waiting: while self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs( waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING) status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, ( assert len(waiting_seqs) == 1, (
...@@ -184,6 +215,17 @@ class Scheduler: ...@@ -184,6 +215,17 @@ class Scheduler:
self.waiting.popleft() self.waiting.popleft()
continue continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group)
self.waiting.popleft()
continue
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens] new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
...@@ -203,12 +245,16 @@ class Scheduler: ...@@ -203,12 +245,16 @@ class Scheduler:
break break
seq_lens = new_seq_lens seq_lens = new_seq_lens
seq_group = self.waiting.popleft() if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
scheduled.append(seq_group) scheduled.append(seq_group)
self.waiting.extendleft(leftover_waiting_sequences)
if scheduled or ignored_seq_groups: if scheduled or ignored_seq_groups:
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled, scheduled_seq_groups=scheduled,
...@@ -256,9 +302,25 @@ class Scheduler: ...@@ -256,9 +302,25 @@ class Scheduler:
if not preempted: if not preempted:
num_curr_seqs = sum(seq_group.get_max_num_running_seqs() num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running) for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
leftover_swapped = deque()
while self.swapped: while self.swapped:
seq_group = self.swapped[0] seq_group = self.swapped[0]
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if lora_int_id > 0 and lora_int_id not in curr_loras and len(
curr_loras) >= self.lora_config.max_loras:
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_swapped.appendleft(seq_group)
self.swapped.popleft()
continue
# If the sequence group cannot be swapped in, stop. # If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group): if not self.block_manager.can_swap_in(seq_group):
break break
...@@ -270,12 +332,16 @@ class Scheduler: ...@@ -270,12 +332,16 @@ class Scheduler:
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
seq_group = self.swapped.popleft() if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in) self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy) self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
self.running.append(seq_group) self.running.append(seq_group)
self.swapped.extendleft(leftover_swapped)
# Each sequence in the generation phase only takes one token slot. # Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of # Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state. # sequences in the RUNNING state.
...@@ -299,10 +365,13 @@ class Scheduler: ...@@ -299,10 +365,13 @@ class Scheduler:
# This function call changes the internal states of the scheduler # This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting. # such as self.running, self.swapped, and self.waiting.
scheduler_outputs = self._schedule() scheduler_outputs = self._schedule()
now = time.time()
# Create input data structures. # Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = [] seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in scheduler_outputs.scheduled_seq_groups: for seq_group in scheduler_outputs.scheduled_seq_groups:
seq_group.maybe_set_first_scheduled_time(now)
seq_data: Dict[int, SequenceData] = {} seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {} block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
...@@ -316,6 +385,9 @@ class Scheduler: ...@@ -316,6 +385,9 @@ class Scheduler:
seq_data=seq_data, seq_data=seq_data,
sampling_params=seq_group.sampling_params, sampling_params=seq_group.sampling_params,
block_tables=block_tables, block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
) )
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs return seq_group_metadata_list, scheduler_outputs
...@@ -327,10 +399,8 @@ class Scheduler: ...@@ -327,10 +399,8 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
self.running = [ self.running = deque(seq_group for seq_group in self.running
seq_group for seq_group in self.running if not seq_group.is_finished())
if not seq_group.is_finished()
]
def _allocate(self, seq_group: SequenceGroup) -> None: def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)
......
...@@ -3,8 +3,8 @@ import dataclasses ...@@ -3,8 +3,8 @@ import dataclasses
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
SchedulerConfig) ParallelConfig, SchedulerConfig, LoRAConfig)
@dataclass @dataclass
...@@ -17,6 +17,7 @@ class EngineArgs: ...@@ -17,6 +17,7 @@ class EngineArgs:
download_dir: Optional[str] = None download_dir: Optional[str] = None
load_format: str = 'auto' load_format: str = 'auto'
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
...@@ -31,10 +32,19 @@ class EngineArgs: ...@@ -31,10 +32,19 @@ class EngineArgs:
max_paddings: int = 256 max_paddings: int = 256
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
code_revision: Optional[str] = None
tokenizer_revision: Optional[str] = None tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None quantization: Optional[str] = None
enforce_eager: bool = False enforce_eager: bool = False
max_context_len_to_capture: int = 8192 max_context_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
def __post_init__(self): def __post_init__(self):
if self.tokenizer is None: if self.tokenizer is None:
...@@ -66,6 +76,13 @@ class EngineArgs: ...@@ -66,6 +76,13 @@ class EngineArgs:
help='the specific model version to use. It can be a branch ' help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument(
'--code-revision',
type=str,
default=None,
help='the specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=str, type=str,
...@@ -115,9 +132,17 @@ class EngineArgs: ...@@ -115,9 +132,17 @@ class EngineArgs:
'The "auto" option will use FP16 precision ' 'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision ' 'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.') 'for BF16 models.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8_e5m2'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. Note FP8 is not supported when cuda version is '
'lower than 11.8.')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=int, type=int,
default=None, default=EngineArgs.max_model_len,
help='model context length. If unspecified, ' help='model context length. If unspecified, '
'will be automatically derived from the model.') 'will be automatically derived from the model.')
# Parallel arguments # Parallel arguments
...@@ -138,6 +163,7 @@ class EngineArgs: ...@@ -138,6 +163,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--max-parallel-loading-workers', '--max-parallel-loading-workers',
type=int, type=int,
default=EngineArgs.max_parallel_loading_workers,
help='load model sequentially in multiple batches, ' help='load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor ' 'to avoid RAM OOM when using tensor '
'parallel and large models') 'parallel and large models')
...@@ -145,9 +171,8 @@ class EngineArgs: ...@@ -145,9 +171,8 @@ class EngineArgs:
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32], choices=[8, 16, 32, 128],
help='token block size') help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
default=EngineArgs.seed, default=EngineArgs.seed,
...@@ -184,7 +209,7 @@ class EngineArgs: ...@@ -184,7 +209,7 @@ class EngineArgs:
'-q', '-q',
type=str, type=str,
choices=['awq', 'gptq', 'squeezellm', None], choices=['awq', 'gptq', 'squeezellm', None],
default=None, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` ' 'None, we first check the `quantization_config` '
'attribute in the model config file. If that is ' 'attribute in the model config file. If that is '
...@@ -202,6 +227,48 @@ class EngineArgs: ...@@ -202,6 +227,48 @@ class EngineArgs:
help='maximum context length covered by CUDA ' help='maximum context length covered by CUDA '
'graphs. When a sequence has context length ' 'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.') 'larger than this, we fall back to eager mode.')
parser.add_argument('--disable-custom-all-reduce',
action='store_true',
default=EngineArgs.disable_custom_all_reduce,
help='See ParallelConfig')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16', 'float32'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron"],
help='Device type for vLLM execution.')
return parser return parser
@classmethod @classmethod
...@@ -214,27 +281,37 @@ class EngineArgs: ...@@ -214,27 +281,37 @@ class EngineArgs:
def create_engine_configs( def create_engine_configs(
self, self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
model_config = ModelConfig(self.model, self.tokenizer, DeviceConfig, Optional[LoRAConfig]]:
self.tokenizer_mode, self.trust_remote_code, device_config = DeviceConfig(self.device)
self.download_dir, self.load_format, model_config = ModelConfig(
self.dtype, self.seed, self.revision, self.model, self.tokenizer, self.tokenizer_mode,
self.tokenizer_revision, self.max_model_len, self.trust_remote_code, self.download_dir, self.load_format,
self.quantization, self.enforce_eager, self.dtype, self.seed, self.revision, self.code_revision,
self.max_context_len_to_capture) self.tokenizer_revision, self.max_model_len, self.quantization,
self.enforce_eager, self.max_context_len_to_capture)
cache_config = CacheConfig(self.block_size, cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization, self.gpu_memory_utilization,
self.swap_space, self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window()) model_config.get_sliding_window())
parallel_config = ParallelConfig(self.pipeline_parallel_size, parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size, self.tensor_parallel_size,
self.worker_use_ray, self.worker_use_ray,
self.max_parallel_loading_workers) self.max_parallel_loading_workers,
self.disable_custom_all_reduce)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, self.max_num_seqs,
model_config.max_model_len, model_config.max_model_len,
self.max_paddings) self.max_paddings)
return model_config, cache_config, parallel_config, scheduler_config lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
return (model_config, cache_config, parallel_config, scheduler_config,
device_config, lora_config)
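A short usage sketch for the widened return value above; engine_args is assumed to be an EngineArgs instance such as the one built in the previous sketch.

(model_config, cache_config, parallel_config, scheduler_config,
 device_config, lora_config) = engine_args.create_engine_configs()
# lora_config is populated only when --enable-lora was passed.
assert (lora_config is None) == (not engine_args.enable_lora)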
@dataclass @dataclass
......
...@@ -4,6 +4,7 @@ from functools import partial ...@@ -4,6 +4,7 @@ from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator) Union, AsyncIterator)
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
...@@ -52,7 +53,7 @@ class AsyncStream: ...@@ -52,7 +53,7 @@ class AsyncStream:
self._queue.put_nowait(item) self._queue.put_nowait(item)
def finish(self) -> None: def finish(self) -> None:
self._queue.put_nowait(StopIteration) self._queue.put_nowait(StopAsyncIteration())
self._finished = True self._finished = True
@property @property
...@@ -64,9 +65,7 @@ class AsyncStream: ...@@ -64,9 +65,7 @@ class AsyncStream:
async def __anext__(self) -> RequestOutput: async def __anext__(self) -> RequestOutput:
result = await self._queue.get() result = await self._queue.get()
if result is StopIteration: if isinstance(result, Exception):
raise StopAsyncIteration
elif isinstance(result, Exception):
raise result raise result
return result return result
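A self-contained sketch of the pattern introduced above: the stream is finished by enqueuing a StopAsyncIteration() instance, and __anext__ simply re-raises any exception it dequeues, which cleanly terminates an async-for loop. The _DemoStream class here is illustrative, not part of vLLM.

import asyncio


class _DemoStream:

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        # Same trick as AsyncStream.finish above.
        self._queue.put_nowait(StopAsyncIteration())

    def __aiter__(self):
        return self

    async def __anext__(self):
        result = await self._queue.get()
        if isinstance(result, Exception):
            raise result
        return result


async def _demo() -> None:
    stream = _DemoStream()
    stream.put("hello")
    stream.finish()
    async for item in stream:
        print(item)  # prints "hello", then the loop ends


asyncio.run(_demo())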
...@@ -203,6 +202,52 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -203,6 +202,52 @@ class _AsyncLLMEngine(LLMEngine):
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
async def encode_request_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
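A hedged sketch of awaiting the coroutine above; the engine instance, request id, prompt, and adapter path are placeholders, and LoRARequest is assumed to take (name, integer id, local path) as in the rest of this diff.

from vllm import SamplingParams
from vllm.lora.request import LoRARequest


async def _submit(engine) -> None:
    await engine.add_request_async(
        request_id="req-0",
        prompt="Hello, my name is",
        sampling_params=SamplingParams(max_tokens=16),
        lora_request=LoRARequest("demo-adapter", 1, "/path/to/adapter"),
    )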
async def _run_workers_async( async def _run_workers_async(
self, self,
method: str, method: str,
...@@ -251,6 +296,8 @@ class AsyncLLMEngine: ...@@ -251,6 +296,8 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the async frontend will be executed in a separate process as the
model workers. model workers.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt token IDs
printed in the log.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
*args: Arguments for LLMEngine. *args: Arguments for LLMEngine.
...@@ -286,6 +333,9 @@ class AsyncLLMEngine: ...@@ -286,6 +333,9 @@ class AsyncLLMEngine:
return (self.background_loop is not None return (self.background_loop is not None
and not self.background_loop.done()) and not self.background_loop.done())
def get_tokenizer(self):
return self.engine.tokenizer.tokenizer
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
if self.is_running: if self.is_running:
...@@ -332,7 +382,7 @@ class AsyncLLMEngine: ...@@ -332,7 +382,7 @@ class AsyncLLMEngine:
if self.engine_use_ray: if self.engine_use_ray:
await self.engine.add_request.remote(**new_request) await self.engine.add_request.remote(**new_request)
else: else:
self.engine.add_request(**new_request) await self.engine.add_request_async(**new_request)
if finished_requests: if finished_requests:
await self._engine_abort(finished_requests) await self._engine_abort(finished_requests)
...@@ -371,6 +421,8 @@ class AsyncLLMEngine: ...@@ -371,6 +421,8 @@ class AsyncLLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests: if self.log_requests:
shortened_prompt = prompt shortened_prompt = prompt
...@@ -383,8 +435,10 @@ class AsyncLLMEngine: ...@@ -383,8 +435,10 @@ class AsyncLLMEngine:
max_log_len] max_log_len]
logger.info(f"Received request {request_id}: " logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, " f"prompt: {shortened_prompt!r}, "
f"sampling params: {sampling_params}, " f"prefix_pos: {prefix_pos},"
f"prompt token ids: {shortened_token_ids}.") f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
...@@ -396,12 +450,30 @@ class AsyncLLMEngine: ...@@ -396,12 +450,30 @@ class AsyncLLMEngine:
"error that caused the background loop to stop " "error that caused the background loop to stop "
"(AsyncEngineDeadError).") "(AsyncEngineDeadError).")
if arrival_time is None:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await self.engine.encode_request_async.remote(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
else:
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
prompt=prompt, prompt=prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time) arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
return stream return stream
...@@ -410,7 +482,9 @@ class AsyncLLMEngine: ...@@ -410,7 +482,9 @@ class AsyncLLMEngine:
prompt: Optional[str], prompt: Optional[str],
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -425,6 +499,12 @@ class AsyncLLMEngine: ...@@ -425,6 +499,12 @@ class AsyncLLMEngine:
request_id: The unique id of the request. request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Yields: Yields:
The output `RequestOutput` objects from the LLMEngine for the The output `RequestOutput` objects from the LLMEngine for the
...@@ -478,11 +558,15 @@ class AsyncLLMEngine: ...@@ -478,11 +558,15 @@ class AsyncLLMEngine:
arrival_time = time.monotonic() arrival_time = time.monotonic()
try: try:
stream = await self.add_request(request_id, stream = await self.add_request(
prompt, request_id,
sampling_params, prompt,
prompt_token_ids=prompt_token_ids, sampling_params,
arrival_time=arrival_time) prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async for request_output in stream: async for request_output in stream:
yield request_output yield request_output
......
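Putting the new generate() parameters together, a hedged client sketch; engine construction is omitted, and the prompt, request id, and prefix position are placeholders.

from vllm import SamplingParams


async def _stream_completion(engine) -> None:
    prompt = "You are a helpful assistant. What is the capital of France?"
    async for request_output in engine.generate(
            prompt,
            SamplingParams(temperature=0.0, max_tokens=32),
            request_id="req-1",
            # Token index of the shared prefix to cache (experimental).
            prefix_pos=8,
    ):
        print(request_output.outputs[0].text)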
...@@ -2,14 +2,17 @@ import copy ...@@ -2,14 +2,17 @@ import copy
from collections import defaultdict from collections import defaultdict
import os import os
import time import time
import pickle
import importlib
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union) Union)
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.lora.request import LoRARequest
SchedulerConfig) from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import record_metrics from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -17,8 +20,9 @@ from vllm.sampling_params import SamplingParams ...@@ -17,8 +20,9 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus) SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer) TokenizerGroup)
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
get_open_port, get_distributed_init_method)
if ray: if ray:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
...@@ -27,8 +31,18 @@ if TYPE_CHECKING: ...@@ -27,8 +31,18 @@ if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
_LOGGING_INTERVAL_SEC = 5 # A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
# If the env var is set, vLLM uses Ray's compiled DAG API,
# which reduces the control-plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class LLMEngine: class LLMEngine:
...@@ -53,6 +67,7 @@ class LLMEngine: ...@@ -53,6 +67,7 @@ class LLMEngine:
management. management.
parallel_config: The configuration related to distributed execution. parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
placement_group: Ray placement group for distributed execution. placement_group: Ray placement group for distributed execution.
Required for distributed execution. Required for distributed execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
...@@ -64,6 +79,8 @@ class LLMEngine: ...@@ -64,6 +79,8 @@ class LLMEngine:
cache_config: CacheConfig, cache_config: CacheConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
placement_group: Optional["PlacementGroup"], placement_group: Optional["PlacementGroup"],
log_stats: bool, log_stats: bool,
) -> None: ) -> None:
...@@ -80,24 +97,24 @@ class LLMEngine: ...@@ -80,24 +97,24 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, " f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, " f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, "
f"quantization={model_config.quantization}, " f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, " f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, "
f"device_config={device_config.device}, "
f"seed={model_config.seed})") f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config self.model_config = model_config
self.cache_config = cache_config self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config
self.log_stats = log_stats self.log_stats = log_stats
self._verify_args() self._verify_args()
self.tokenizer = get_tokenizer( self._init_tokenizer()
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
tokenizer_revision=model_config.tokenizer_revision,
revision=model_config.revision)
self.seq_counter = Counter() self.seq_counter = Counter()
# Create the parallel GPU workers. # Create the parallel GPU workers.
...@@ -114,37 +131,67 @@ class LLMEngine: ...@@ -114,37 +131,67 @@ class LLMEngine:
self._init_cache() self._init_cache()
# Create the scheduler. # Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config) self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.model))
self.stat_logger.info("cache_config", self.cache_config)
# Logging. self.forward_dag = None
self.last_logging_time = 0.0 if USE_RAY_COMPILED_DAG:
# List of (timestamp, num_tokens) self.forward_dag = self._compiled_ray_dag()
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens) def get_tokenizer_for_seq(self, sequence: Sequence):
self.num_generation_tokens: List[Tuple[float, int]] = [] return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker
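The dispatch above is a plain dynamic import; as a standalone illustration, using the CUDA entry of DEVICE_TO_WORKER_MODULE_MAP:

import importlib

worker_module = importlib.import_module("vllm.worker.worker")
Worker = worker_module.Worker  # same class the previous eager import produced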
def _init_workers(self): def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker Worker = self._dispatch_worker()
assert self.parallel_config.world_size == 1, ( assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.") "Ray is required if parallel_config.world_size > 1.")
self.workers: List[Worker] = [] self.workers: List[Worker] = []
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = Worker( self.driver_worker = Worker(
self.model_config, self.model_config,
self.parallel_config, self.parallel_config,
self.scheduler_config, self.scheduler_config,
self.device_config,
local_rank=0, local_rank=0,
rank=0, rank=0,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
self._run_workers("init_model") self._run_workers("init_model")
self._run_workers("load_model") self._run_workers("load_model")
def _init_tokenizer(self, **tokenizer_init_kwargs):
init_kwargs = dict(
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer: TokenizerGroup = TokenizerGroup(
self.model_config.tokenizer, **init_kwargs)
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if self.parallel_config.tensor_parallel_size == 1:
...@@ -207,16 +254,18 @@ class LLMEngine: ...@@ -207,16 +254,18 @@ class LLMEngine:
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id]) worker.set_cuda_visible_devices.remote(node_gpus[node_id])
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}" distributed_init_method = get_distributed_init_method(
driver_ip, get_open_port())
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker Worker = self._dispatch_worker()
# Initialize torch distributed process group for the workers. # Initialize torch distributed process group for the workers.
model_config = copy.deepcopy(self.model_config) model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config) parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config) scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
for rank, (worker, (node_id, for rank, (worker, (node_id,
_)) in enumerate(zip(self.workers, _)) in enumerate(zip(self.workers,
...@@ -228,9 +277,12 @@ class LLMEngine: ...@@ -228,9 +277,12 @@ class LLMEngine:
model_config, model_config,
parallel_config, parallel_config,
scheduler_config, scheduler_config,
device_config,
local_rank, local_rank,
rank, rank,
distributed_init_method, distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
)) ))
driver_rank = 0 driver_rank = 0
...@@ -239,13 +291,19 @@ class LLMEngine: ...@@ -239,13 +291,19 @@ class LLMEngine:
model_config, model_config,
parallel_config, parallel_config,
scheduler_config, scheduler_config,
device_config,
driver_local_rank, driver_local_rank,
driver_rank, driver_rank,
distributed_init_method, distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True, is_driver_worker=True,
) )
self._run_workers("init_model") # don't use cupy for eager mode
self._run_workers("init_model",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers( self._run_workers(
"load_model", "load_model",
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.
...@@ -255,6 +313,10 @@ class LLMEngine: ...@@ -255,6 +313,10 @@ class LLMEngine:
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def _init_cache(self) -> None: def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache. """Profiles the memory usage and initializes the KV cache.
...@@ -283,6 +345,7 @@ class LLMEngine: ...@@ -283,6 +345,7 @@ class LLMEngine:
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization, gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes, cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
) )
# Since we use a shared centralized controller, we take the minimum # Since we use a shared centralized controller, we take the minimum
...@@ -330,6 +393,20 @@ class LLMEngine: ...@@ -330,6 +393,20 @@ class LLMEngine:
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine
def encode_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
...@@ -337,6 +414,8 @@ class LLMEngine: ...@@ -337,6 +414,8 @@ class LLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None, prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -353,6 +432,11 @@ class LLMEngine: ...@@ -353,6 +432,11 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs. use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
...@@ -378,20 +462,35 @@ class LLMEngine: ...@@ -378,20 +462,35 @@ class LLMEngine:
>>> # continue the request processing >>> # continue the request processing
>>> ... >>> ...
""" """
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.monotonic() arrival_time = time.monotonic()
if prompt_token_ids is None: prompt_token_ids = self.encode_request(
assert prompt is not None request_id=request_id,
prompt_token_ids = self.tokenizer.encode(prompt) prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)
# Check whether the input specifies a prefix.
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Defensive copy of SamplingParams, which is used by the sampler.
# Note that this does not deep-copy LogitsProcessor objects.
sampling_params = sampling_params.clone()
# Create the sequence group. # Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time) arrival_time, lora_request, prefix)
# Add the sequence group to the scheduler. # Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group) self.scheduler.add_seq_group(seq_group)
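A hedged sketch of driving the synchronous engine through the extended add_request; engine construction is omitted, and has_unfinished_requests/step are assumed from the existing LLMEngine API.

from vllm import SamplingParams


def _run(llm_engine) -> None:
    llm_engine.add_request(
        request_id="req-2",
        prompt="def fib(n):",
        sampling_params=SamplingParams(max_tokens=64),
        prefix_pos=None,  # or a token index to opt into prefix caching
    )
    while llm_engine.has_unfinished_requests():
        for request_output in llm_engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)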
...@@ -441,11 +540,13 @@ class LLMEngine: ...@@ -441,11 +540,13 @@ class LLMEngine:
current_worst_score = (current_worst_seq.get_beam_search_score( current_worst_score = (current_worst_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id)) eos_token_id=self.get_tokenizer_for_seq(
current_worst_seq).eos_token_id))
if early_stopping is False: if early_stopping is False:
highest_attainable_score = (best_running_seq.get_beam_search_score( highest_attainable_score = (best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id)) eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id))
else: else:
assert early_stopping == "never" assert early_stopping == "never"
if length_penalty > 0.0: if length_penalty > 0.0:
...@@ -459,7 +560,8 @@ class LLMEngine: ...@@ -459,7 +560,8 @@ class LLMEngine:
highest_attainable_score = ( highest_attainable_score = (
best_running_seq.get_beam_search_score( best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id, eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id,
seq_len=max_possible_length)) seq_len=max_possible_length))
else: else:
# Otherwise, beam search will prefer shorter sequences. The # Otherwise, beam search will prefer shorter sequences. The
...@@ -468,11 +570,13 @@ class LLMEngine: ...@@ -468,11 +570,13 @@ class LLMEngine:
highest_attainable_score = ( highest_attainable_score = (
best_running_seq.get_beam_search_score( best_running_seq.get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id)) eos_token_id=self.get_tokenizer_for_seq(
best_running_seq).eos_token_id))
return current_worst_score >= highest_attainable_score return current_worst_score >= highest_attainable_score
def _process_sequence_group_outputs(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput) -> None: outputs: SequenceGroupOutput) -> None:
# Process prompt logprobs # Process prompt logprobs
prompt_logprobs = outputs.prompt_logprobs prompt_logprobs = outputs.prompt_logprobs
if prompt_logprobs is not None: if prompt_logprobs is not None:
...@@ -559,7 +663,7 @@ class LLMEngine: ...@@ -559,7 +663,7 @@ class LLMEngine:
# Sort the finished sequences by their scores. # Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id), eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True) reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]: for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new: if is_new:
...@@ -587,7 +691,7 @@ class LLMEngine: ...@@ -587,7 +691,7 @@ class LLMEngine:
# Sort the running sequences by their scores. # Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id), eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
reverse=True) reverse=True)
# Check if we can stop the beam search. # Check if we can stop the beam search.
...@@ -645,6 +749,7 @@ class LLMEngine: ...@@ -645,6 +749,7 @@ class LLMEngine:
def _process_model_outputs( def _process_model_outputs(
self, output: SamplerOutput, self, output: SamplerOutput,
scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
now = time.time()
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, outputs in zip(scheduled_seq_groups, output): for seq_group, outputs in zip(scheduled_seq_groups, output):
...@@ -656,16 +761,23 @@ class LLMEngine: ...@@ -656,16 +761,23 @@ class LLMEngine:
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups: for seq_group in scheduled_seq_groups:
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups: for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
# Update prefix state: all previously uncomputed prefixes are now computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated
and not seq_group.prefix.computed):
seq_group.prefix.computed = True
# Log stats.
if self.log_stats: if self.log_stats:
# Log the system stats. self.stat_logger.log(self._get_stats(scheduler_outputs))
self._log_system_stats(scheduler_outputs.prompt_run,
scheduler_outputs.num_batched_tokens)
return request_outputs return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[RequestOutput]:
...@@ -730,7 +842,8 @@ class LLMEngine: ...@@ -730,7 +842,8 @@ class LLMEngine:
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy, "blocks_to_copy": scheduler_outputs.blocks_to_copy,
}) },
use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
output = all_outputs[0] output = all_outputs[0]
...@@ -740,86 +853,84 @@ class LLMEngine: ...@@ -740,86 +853,84 @@ class LLMEngine:
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
def do_log_stats(self) -> None: def do_log_stats(self) -> None:
self._log_system_stats(False, 0) """Forced log when no requests active."""
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs=None))
def _log_system_stats( def _get_stats(self,
self, scheduler_outputs: Optional[SchedulerOutputs]) -> Stats:
prompt_run: bool, """Get Stats to be Logged to Prometheus."""
num_batched_tokens: int,
) -> None:
now = time.monotonic() now = time.monotonic()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC # KV Cache Usage in %.
if not should_log: num_total_gpu = self.cache_config.num_gpu_blocks
return num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
# Discard the old stats. num_total_cpu = self.cache_config.num_cpu_blocks
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens cpu_cache_usage = 0.
if now - t < _LOGGING_INTERVAL_SEC] if num_total_cpu > 0:
self.num_generation_tokens = [(t, n) num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
for t, n in self.num_generation_tokens )
if now - t < _LOGGING_INTERVAL_SEC] cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
if len(self.num_prompt_tokens) > 1: # Scheduler State
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1]) num_running = len(self.scheduler.running)
window = now - self.num_prompt_tokens[0][0] num_swapped = len(self.scheduler.swapped)
avg_prompt_throughput = total_num_tokens / window num_waiting = len(self.scheduler.waiting)
else:
avg_prompt_throughput = 0.0 # Iteration stats if we have scheduler output.
if len(self.num_generation_tokens) > 1: num_prompt_tokens = 0
total_num_tokens = sum(n num_generation_tokens = 0
for _, n in self.num_generation_tokens[:-1]) time_to_first_tokens = []
window = now - self.num_generation_tokens[0][0] time_per_output_tokens = []
avg_generation_throughput = total_num_tokens / window time_e2e_requests = []
else: if scheduler_outputs is not None:
avg_generation_throughput = 0.0 prompt_run = scheduler_outputs.prompt_run
total_num_gpu_blocks = self.cache_config.num_gpu_blocks # Number of Tokens.
num_free_gpu_blocks = ( if prompt_run:
self.scheduler.block_manager.get_num_free_gpu_blocks()) num_prompt_tokens = sum(
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks len(seq_group.prompt_token_ids)
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks for seq_group in scheduler_outputs.scheduled_seq_groups)
num_generation_tokens = sum(
total_num_cpu_blocks = self.cache_config.num_cpu_blocks seq_group.num_seqs()
if total_num_cpu_blocks > 0: for seq_group in scheduler_outputs.scheduled_seq_groups)
num_free_cpu_blocks = ( else:
self.scheduler.block_manager.get_num_free_cpu_blocks()) num_generation_tokens = scheduler_outputs.num_batched_tokens
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks # Latency Timings.
else: time_last_iters = []
cpu_cache_usage = 0.0 for seq_group in scheduler_outputs.scheduled_seq_groups:
# Time since last token. (n.b. updates seq_group.metrics.last_token_time)
record_metrics( time_last_iters.append(seq_group.get_last_latency(now))
avg_prompt_throughput=avg_prompt_throughput, # Time since arrival for all finished requests.
avg_generation_throughput=avg_generation_throughput, if seq_group.is_finished():
scheduler_running=len(self.scheduler.running), time_e2e_requests.append(now -
scheduler_swapped=len(self.scheduler.swapped), seq_group.metrics.arrival_time)
scheduler_waiting=len(self.scheduler.waiting),
time_to_first_tokens = time_last_iters if prompt_run else []
time_per_output_tokens = [] if prompt_run else time_last_iters
return Stats(
now=now,
num_running=num_running,
num_swapped=num_swapped,
num_waiting=num_waiting,
gpu_cache_usage=gpu_cache_usage, gpu_cache_usage=gpu_cache_usage,
cpu_cache_usage=cpu_cache_usage, cpu_cache_usage=cpu_cache_usage,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=num_generation_tokens,
time_to_first_tokens=time_to_first_tokens,
time_per_output_tokens=time_per_output_tokens,
time_e2e_requests=time_e2e_requests,
) )
logger.info("Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
"""Decodes the new token for a sequence.""" """Decodes the new token for a sequence."""
(new_tokens, new_output_text, prefix_offset, (new_tokens, new_output_text, prefix_offset,
read_offset) = detokenize_incrementally( read_offset) = detokenize_incrementally(
self.tokenizer, self.get_tokenizer_for_seq(seq),
all_input_ids=seq.get_token_ids(), all_input_ids=seq.get_token_ids(),
prev_tokens=seq.tokens, prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset, prefix_offset=seq.prefix_offset,
...@@ -840,13 +951,13 @@ class LLMEngine: ...@@ -840,13 +951,13 @@ class LLMEngine:
"""Stop the finished sequences.""" """Stop the finished sequences."""
for stop_str in sampling_params.stop: for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str): if seq.output_text.endswith(stop_str):
if not sampling_params.include_stop_str_in_output: self._finalize_sequence(seq, sampling_params, stop_str)
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
if seq.get_last_token_id() in sampling_params.stop_token_ids: if seq.get_last_token_id() in sampling_params.stop_token_ids:
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
seq.get_last_token_id())
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
...@@ -861,11 +972,39 @@ class LLMEngine: ...@@ -861,11 +972,39 @@ class LLMEngine:
return return
# Check if the sequence has generated the EOS token. # Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos) if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
and seq.get_last_token_id() == self.tokenizer.eos_token_id): == self.get_tokenizer_for_seq(seq).eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
def _finalize_sequence(self, seq: Sequence,
sampling_params: SamplingParams,
stop_string: str) -> None:
if sampling_params.include_stop_str_in_output:
return
if stop_string and seq.output_text.endswith(stop_string):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_string)]
def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"add_lora",
lora_request=lora_request,
)
def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"remove_lora",
lora_id=lora_id,
)
def list_loras(self) -> List[int]:
return self._run_workers("list_loras")
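A hedged sketch of registering and removing adapters with the helpers above; llm_engine is assumed to have been built with LoRA enabled, and the adapter path is a placeholder.

from vllm.lora.request import LoRARequest

lora = LoRARequest("sql-adapter", 1, "/path/to/sql-adapter")
llm_engine.add_lora(lora)
print(llm_engine.list_loras())
llm_engine.remove_lora(lora.lora_int_id)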
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
...@@ -873,6 +1012,7 @@ class LLMEngine: ...@@ -873,6 +1012,7 @@ class LLMEngine:
driver_args: Optional[List[Any]] = None, driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers."""
...@@ -881,11 +1021,16 @@ class LLMEngine: ...@@ -881,11 +1021,16 @@ class LLMEngine:
raise NotImplementedError( raise NotImplementedError(
"max_concurrent_workers is not supported yet.") "max_concurrent_workers is not supported yet.")
# Start the ray workers first. if use_ray_compiled_dag:
ray_worker_outputs = [ # Right now, compiled DAG can only accept a single
worker.execute_method.remote(method, *args, **kwargs) # input. TODO(sang): Fix it.
for worker in self.workers output_channels = self.forward_dag.execute(1)
] else:
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]
if driver_args is None: if driver_args is None:
driver_args = args driver_args = args
...@@ -898,6 +1043,37 @@ class LLMEngine: ...@@ -898,6 +1043,37 @@ class LLMEngine:
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs) if use_ray_compiled_dag:
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs return [driver_worker_output] + ray_worker_outputs
def _compiled_ray_dag(self):
import pkg_resources
required_version = "2.9"
current_version = pkg_resources.get_distribution("ray").version
if (pkg_resources.parse_version(current_version) <
        pkg_resources.parse_version(required_version)):
raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}")
from ray.dag import MultiOutputNode, InputNode
assert self.parallel_config.worker_use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with InputNode() as input_data:
forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote.bind(input_data)
for worker in self.workers
])
return forward_dag.experimental_compile()
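As the comments above note, the compiled-DAG path is opt-in via an environment variable read at module import time; a hedged sketch of enabling it (requires Ray >= 2.9 and worker_use_ray).

import os

# Must be set before vllm.engine.llm_engine is imported, because the
# USE_RAY_COMPILED_DAG flag is evaluated at import time.
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"

from vllm.engine.llm_engine import LLMEngine  # noqa: E402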
from aioprometheus import Gauge from vllm.logger import init_logger
from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics
import time
import numpy as np
from typing import Dict, List
from dataclasses import dataclass
logger = init_logger(__name__)
disable_created_metrics()
# The begin-* and end* here are used by the documentation generator # The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions. # to extract the metrics definitions.
# begin-metrics-definitions # begin-metrics-definitions
gauge_avg_prompt_throughput = Gauge("vllm:avg_prompt_throughput_toks_per_s", class Metrics:
"Average prefill throughput in tokens/s.")
gauge_avg_generation_throughput = Gauge( def __init__(self, labelnames: List[str]):
"vllm:avg_generation_throughput_toks_per_s", # Unregister any existing vLLM collectors
"Average generation throughput in tokens/s.") for collector in list(REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
gauge_scheduler_running = Gauge( REGISTRY.unregister(collector)
"vllm:num_requests_running",
"Number of requests that is currently running for inference.") self.info_cache_config = Info(
gauge_scheduler_swapped = Gauge("vllm:num_requests_swapped", name='vllm:cache_config',
"Number requests swapped to CPU.") documentation='information of cache_config')
gauge_scheduler_waiting = Gauge("vllm:num_requests_waiting",
"Number of requests waiting to be processed.") # System stats
self.gauge_scheduler_running = Gauge(
gauge_gpu_cache_usage = Gauge( name="vllm:num_requests_running",
"vllm:gpu_cache_usage_perc", documentation="Number of requests currently running on GPU.",
"GPU KV-cache usage. 1 means 100 percent usage.") labelnames=labelnames)
gauge_cpu_cache_usage = Gauge( self.gauge_scheduler_swapped = Gauge(
"vllm:cpu_cache_usage_perc", name="vllm:num_requests_swapped",
"CPU KV-cache usage. 1 means 100 percent usage.") documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = Gauge(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_gpu_cache_usage = Gauge(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = Gauge(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Raw stats from last model iteration
self.counter_prompt_tokens = Counter(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = Counter(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = Histogram(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
buckets=[
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = Histogram(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5
])
self.histogram_e2e_request_latency = Histogram(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Legacy metrics
self.gauge_avg_prompt_throughput = Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
self.gauge_avg_generation_throughput = Gauge(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
# end-metrics-definitions # end-metrics-definitions
labels = {}
@dataclass
class Stats:
def add_global_metrics_labels(**kwargs): """Created by LLMEngine for use by StatLogger."""
labels.update(kwargs) now: float
# System stats.
def record_metrics( num_running: int
avg_prompt_throughput: float, num_waiting: int
avg_generation_throughput: float, num_swapped: int
scheduler_running: int, gpu_cache_usage: float
scheduler_swapped: int, cpu_cache_usage: float
scheduler_waiting: int,
gpu_cache_usage: float, # Raw stats from last model iteration.
cpu_cache_usage: float, num_prompt_tokens: int
): num_generation_tokens: int
gauge_avg_prompt_throughput.set(labels, avg_prompt_throughput) time_to_first_tokens: List[float]
gauge_avg_generation_throughput.set(labels, avg_generation_throughput) time_per_output_tokens: List[float]
gauge_scheduler_running.set(labels, scheduler_running) time_e2e_requests: List[float]
gauge_scheduler_swapped.set(labels, scheduler_swapped)
gauge_scheduler_waiting.set(labels, scheduler_waiting)
gauge_gpu_cache_usage.set(labels, gpu_cache_usage) class StatLogger:
gauge_cpu_cache_usage.set(labels, cpu_cache_usage) """StatLogger is used by LLMEngine to log to Prometheus and stdout."""
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally.
self.last_local_log = time.monotonic()
self.local_interval = local_interval
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
# Prometheus metrics
self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None:
if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info())
def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
return float(np.sum(tracked_stats) / (now - self.last_local_log))
def _local_interval_elapsed(self, now: float) -> bool:
elapsed_time = now - self.last_local_log
return elapsed_time > self.local_interval
def _log_prometheus(self, stats: Stats) -> None:
# Set system stat gauges.
self.metrics.gauge_scheduler_running.labels(**self.labels).set(
stats.num_running)
self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
stats.num_swapped)
self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
stats.num_waiting)
self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
stats.gpu_cache_usage)
self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
stats.cpu_cache_usage)
# Add to token counters.
self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
stats.num_prompt_tokens)
self.metrics.counter_generation_tokens.labels(**self.labels).inc(
stats.num_generation_tokens)
# Observe request level latencies in histograms.
for ttft in stats.time_to_first_tokens:
self.metrics.histogram_time_to_first_token.labels(
**self.labels).observe(ttft)
for tpot in stats.time_per_output_tokens:
self.metrics.histogram_time_per_output_token.labels(
**self.labels).observe(tpot)
for e2e in stats.time_e2e_requests:
self.metrics.histogram_e2e_request_latency.labels(
**self.labels).observe(e2e)
def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
# Logs metrics to prometheus that are computed every logging_interval.
# Support legacy gauge metrics that make throughput calculations on the vLLM side.
# Moving forward, we should use counters like counter_prompt_tokens and
# counter_generation_tokens, which log raw data and compute summaries using rate() on the grafana/prometheus side.
# See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
self.metrics.gauge_avg_prompt_throughput.labels(
**self.labels).set(prompt_throughput)
self.metrics.gauge_avg_generation_throughput.labels(
**self.labels).set(generation_throughput)
def log(self, stats: Stats) -> None:
"""Called by LLMEngine.
Logs to prometheus and tracked stats every iteration.
Logs to Stdout every self.local_interval seconds."""
# Log to prometheus.
self._log_prometheus(stats)
# Save tracked stats for token counters.
self.num_prompt_tokens.append(stats.num_prompt_tokens)
self.num_generation_tokens.append(stats.num_generation_tokens)
# Log locally every local_interval seconds.
if self._local_interval_elapsed(stats.now):
# Compute summary metrics for tracked stats (and log them to promethus if applicable).
prompt_throughput = self._get_throughput(self.num_prompt_tokens,
now=stats.now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now=stats.now)
self._log_prometheus_interval(
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)
# Log to stdout.
logger.info(
f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
f"Avg generation throughput: {generation_throughput:.1f} tokens/s, "
f"Running: {stats.num_running} reqs, "
f"Swapped: {stats.num_swapped} reqs, "
f"Pending: {stats.num_waiting} reqs, "
f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%")
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
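A self-contained sketch exercising the Stats/StatLogger pair defined above; the field values are made up, and the label key must match the labelnames given to Metrics.

import time

from vllm.engine.metrics import StatLogger, Stats

stat_logger = StatLogger(local_interval=5.0, labels={"model_name": "demo"})
stat_logger.log(
    Stats(
        now=time.monotonic(),
        num_running=1,
        num_waiting=0,
        num_swapped=0,
        gpu_cache_usage=0.25,
        cpu_cache_usage=0.0,
        num_prompt_tokens=128,
        num_generation_tokens=32,
        time_to_first_tokens=[0.05],
        time_per_output_tokens=[0.02, 0.02],
        time_e2e_requests=[],
    ))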
import pickle
from typing import Optional, List, Tuple, TYPE_CHECKING from typing import Optional, List, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
...@@ -18,6 +20,11 @@ try: ...@@ -18,6 +20,11 @@ try:
from transformers.dynamic_module_utils import init_hf_modules from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules() init_hf_modules()
self.worker = None self.worker = None
# The compiled DAG runs the main execution in a separate
# thread that calls cuda.set_device; this flag records
# whether set_device has been called on that thread.
self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn): def init_worker(self, worker_init_fn):
self.worker = worker_init_fn() self.worker = worker_init_fn()
...@@ -40,10 +47,21 @@ try: ...@@ -40,10 +47,21 @@ try:
def set_cuda_visible_devices(self, device_ids) -> None: def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids) set_cuda_visible_devices(device_ids)
def execute_model_compiled_dag_remote(self, ignored):
"""Used only when compiled DAG is enabled."""
import torch
if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
output = self.worker.execute_model()
output = pickle.dumps(output)
return output
except ImportError as e: except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. " logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with " "For distributed inference, please install Ray with "
"`pip install ray pandas pyarrow`.") "`pip install ray`.")
ray = None ray = None
RayWorkerVllm = None RayWorkerVllm = None
...@@ -65,10 +83,9 @@ def initialize_cluster( ...@@ -65,10 +83,9 @@ def initialize_cluster(
the default Ray cluster address. the default Ray cluster address.
Returns: Returns:
A tuple of (`distributed_init_method`, `placement_group`). The An optional `PlacementGroup`. It includes the specification
`distributed_init_method` is the address for initializing the of the resources for each distributed worker. None if Ray is
distributed backend. `placement_group` includes the specification not used.
of the resources for each distributed worker.
""" """
if parallel_config.worker_use_ray or engine_use_ray: if parallel_config.worker_use_ray or engine_use_ray:
if ray is None: if ray is None:
......
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks.
It is not intended for production use; for production, we recommend our OpenAI-compatible server.
We are also not going to accept PRs modifying this file; please change `vllm/entrypoints/openai/api_server.py` instead.
"""
import argparse import argparse
import json import json
from typing import AsyncGenerator from typing import AsyncGenerator
...@@ -33,11 +39,15 @@ async def generate(request: Request) -> Response: ...@@ -33,11 +39,15 @@ async def generate(request: Request) -> Response:
""" """
request_dict = await request.json() request_dict = await request.json()
prompt = request_dict.pop("prompt") prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False) stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
# Streaming case # Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]: async def stream_results() -> AsyncGenerator[bytes, None]:
......
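As a rough illustration of the `prefix_pos` field handled above: a client includes it in the JSON body posted to this demo `/generate` endpoint, and every remaining field is forwarded to SamplingParams. The prompt text, prefix length, host/port, and the use of the `requests` library below are illustrative assumptions, not part of the diff.

import requests

# Fields other than "prompt", "prefix_pos", and "stream" are passed
# straight through to SamplingParams by the handler above.
payload = {
    "prompt": "San Francisco is a city in California. It is famous for",
    "prefix_pos": 8,  # hypothetical: reuse the KV cache of the first 8 prompt tokens
    "stream": False,
    "temperature": 0.8,
    "max_tokens": 64,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(resp.json())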
...@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
...@@ -63,6 +64,7 @@ class LLM:
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
...@@ -81,6 +83,7 @@ class LLM:
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
...@@ -100,6 +103,7 @@ class LLM:
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
...@@ -107,20 +111,22 @@ class LLM:
def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
...@@ -134,7 +140,13 @@ class LLM:
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `RequestOutput` objects containing the generated
...@@ -159,9 +171,14 @@ class LLM:
prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
return self._run_engine(use_tqdm)
def _add_request(
...@@ -169,10 +186,16 @@ class LLM:
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
......
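To make the new `generate()` arguments concrete, here is a minimal sketch of calling it with a shared prefix; the model name, prompts, and the way the prefix length is computed are illustrative assumptions, and `lora_request` would be passed analogously when a LoRA adapter is loaded.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # hypothetical small model for illustration
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

shared_prefix = "You are a helpful assistant. Answer concisely.\n"
prompts = [
    shared_prefix + "Question: What is the capital of France?",
    shared_prefix + "Question: What is the capital of Japan?",
]
# prefix_pos marks how many leading prompt tokens form the shared prefix,
# so their KV cache can be reused across requests (experimental feature).
prefix_len = len(llm.get_tokenizer()(shared_prefix).input_ids)
outputs = llm.generate(prompts,
                       sampling_params,
                       prefix_pos=[prefix_len, prefix_len])
for output in outputs:
    print(output.outputs[0].text)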
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
import asyncio
import codecs
import json
import time
from contextlib import asynccontextmanager
import os
import importlib
import inspect
from prometheus_client import make_asgi_app
from aioprometheus.asgi.starlette import metrics
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
...@@ -21,27 +17,17 @@ from fastapi.responses import JSONResponse, StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)
served_model = None
engine_args = None
engine = None
response_role = None
@asynccontextmanager
...@@ -61,6 +47,16 @@ async def lifespan(app: fastapi.FastAPI):
app = fastapi.FastAPI(lifespan=lifespan)
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)
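A self-contained sketch of what this parser action produces from name=path pairs; the namedtuple below is only a stand-in for the imported LoRA class, and the module names and paths are made up.

import argparse
from collections import namedtuple

LoRA = namedtuple("LoRA", ["name", "local_path"])  # stand-in for the real class

class LoRAParserAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        # Same splitting logic as above: each item is "<name>=<local path>".
        setattr(namespace, self.dest,
                [LoRA(*item.split('=')) for item in values])

parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args(
    ["--lora-modules", "sql-lora=/adapters/sql", "chat-lora=/adapters/chat"])
print(args.lora_modules)
# [LoRA(name='sql-lora', local_path='/adapters/sql'),
#  LoRA(name='chat-lora', local_path='/adapters/chat')]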
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
...@@ -81,12 +77,28 @@ def parse_args():
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--api-key",
type=str,
default=None,
help=
"If provided, the server will require this key to be presented in the header."
)
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
parser.add_argument("--chat-template",
type=str,
default=None,
...@@ -111,81 +123,31 @@ def parse_args():
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
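For the `--middleware` flag added above, both accepted forms can live in an ordinary importable module; the module and object names here are hypothetical, and the server would then be started with e.g. `--middleware my_middleware.log_requests` or `--middleware my_middleware.AddVersionHeaderMiddleware`.

# my_middleware.py (hypothetical module on PYTHONPATH)
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

async def log_requests(request: Request, call_next):
    # Function form: the server registers it via @app.middleware("http").
    response = await call_next(request)
    print(f"{request.method} {request.url.path} -> {response.status_code}")
    return response

class AddVersionHeaderMiddleware(BaseHTTPMiddleware):
    # Class form: the server registers it via app.add_middleware().
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers["x-served-by"] = "vllm-openai-server"
        return response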
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
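With the Prometheus ASGI app mounted above, metrics can be scraped from the same server; the host and port below are assumptions based on the defaults.

import requests

# /metrics returns the Prometheus text exposition format.
metrics_text = requests.get("http://localhost:8000/metrics").text
print("\n".join(metrics_text.splitlines()[:10]))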
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
def load_chat_template(args, tokenizer):
if args.chat_template is not None:
try:
with open(args.chat_template, "r") as f:
chat_template = f.read()
except OSError:
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
chat_template = codecs.decode(args.chat_template, "unicode_escape")
tokenizer.chat_template = chat_template
logger.info(
f"Using supplied chat template:\n{tokenizer.chat_template}")
elif tokenizer.chat_template is not None:
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
else:
logger.warning("No chat template provided. Chat API will not work.")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None
) -> Tuple[List[int], Optional[JSONResponse]]:
assert (not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided."
input_ids = prompt_ids if prompt_ids is not None else tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = max_model_len - token_num
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return input_ids, None
@app.get("/health")
...@@ -196,544 +158,37 @@ async def health() -> Response:
@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
return JSONResponse(content=models.model_dump())
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
def create_logprobs(
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
else:
token_logprob = None
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
try:
prompt = tokenizer.apply_chat_template(
conversation=request.messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
except Exception as e:
logger.error(f"Error in applying chat template from request: {str(e)}")
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
chunk_object_type = "chat.completion.chunk"
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
def get_role() -> str:
if request.add_generation_prompt:
return response_role
else:
return request.messages[-1]["role"]
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# Send first response for each request.n (index) with the role
role = get_role()
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=DeltaMessage(role=role), finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
# Send response for each token for each request.n (index)
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
if output.finish_reason is None:
# Send token-by-token response for each request.n
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=[], finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.json(exclude_unset=True,
exclude_none=True,
ensure_ascii=False)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def completion_full_generator():
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
role = get_role()
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
# Streaming response
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
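Putting the new delegation path together, a minimal client call against this endpoint looks like the sketch below; the served model name, API key, and host/port are placeholders, and the Authorization header is only needed when the server is started with `--api-key` (or VLLM_API_KEY).

import requests

headers = {"Authorization": "Bearer my-secret-key"}  # placeholder key
payload = {
    "model": "my-served-model",  # placeholder; must match the served model name
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 32,
    "temperature": 0.7,
}
resp = requests.post("http://localhost:8000/v1/chat/completions",
                     json=payload,
                     headers=headers)
print(resp.json()["choices"][0]["message"]["content"])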
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
NOTE: Currently we do not support the following features:
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
# OpenAI API supports echoing the prompt when max_tokens is 0.
echo_without_generation = request.echo and request.max_tokens == 0
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handles multiple prompt case in list[list[int]]
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.monotonic())
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
min_p=request.min_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens
if not echo_without_generation else 1,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
prompt_logprobs=request.logprobs if request.echo else None,
skip_special_tokens=request.skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
if use_token_ids:
result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
if usage is not None:
response.usage = usage
response_json = response.json(exclude_unset=True, ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
has_echoed = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
token_ids = output.token_ids[previous_num_tokens[i]:]
if request.logprobs is not None:
top_logprobs = output.logprobs[previous_num_tokens[i]:]
else:
top_logprobs = None
offsets = len(previous_texts[i])
if request.echo and not has_echoed[i]:
if not echo_without_generation:
delta_text = res.prompt + delta_text
token_ids = res.prompt_token_ids + token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs + top_logprobs
else: # only just return the prompt
delta_text = res.prompt
token_ids = res.prompt_token_ids
if top_logprobs:
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
if request.logprobs is not None:
logprobs = create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=offsets,
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
usage=final_usage,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.logprobs is not None:
if not echo_without_generation:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.echo:
token_ids = prompt_token_ids + token_ids
top_logprobs = prompt_logprobs + top_logprobs
else:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
logprobs = create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
if not echo_without_generation:
output_text = output.text
if request.echo:
output_text = prompt_text + output_text
else:
output_text = prompt_text
choice_data = CompletionResponseChoice(
index=output.index,
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
return StreamingResponse(content=generator,
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
...@@ -747,6 +202,29 @@ if __name__ == "__main__":
allow_headers=args.allowed_headers,
)
if token := os.environ.get("VLLM_API_KEY") or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
if not request.url.path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(
f"Invalid middleware {middleware}. Must be a function or a class."
)
logger.info(f"args: {args}")
if args.served_model_name is not None:
...@@ -754,22 +232,14 @@ if __name__ == "__main__":
else:
served_model = args.model
response_role = args.response_role
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
load_chat_template(args, tokenizer)
# Register labels for metrics
add_global_metrics_labels(model_name=engine_args.model)
app.root_path = args.root_path
uvicorn.run(app,
......