Commit ad58e9b3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1.post2' into v0.6.1.post2-dev

parents 408f663a 9ba0817f
...@@ -10,6 +10,7 @@ from pathlib import Path ...@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import openai import openai
import pytest
import requests import requests
from openai.types.completion import Completion from openai.types.completion import Completion
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -22,7 +23,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -22,7 +23,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
get_open_port, is_hip)
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage, from amdsmi import (amdsmi_get_gpu_vram_usage,
...@@ -356,12 +358,23 @@ def error_on_warning(): ...@@ -356,12 +358,23 @@ def error_on_warning():
yield yield
def get_physical_device_indices(devices: List[int]) -> List[int]:
    """Map logical CUDA device indices to physical device indices.

    When ``CUDA_VISIBLE_DEVICES`` is set, logical index ``i`` refers to the
    ``i``-th entry of that comma-separated list. Logical indices that have no
    corresponding entry are dropped from the result.

    Args:
        devices: Logical device indices as seen by the current process.

    Returns:
        ``devices`` unchanged when ``CUDA_VISIBLE_DEVICES`` is unset;
        otherwise the physical indices for every logical index that is
        covered by the variable.

    Raises:
        ValueError: If a non-blank entry is not an integer (e.g. a GPU UUID).
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return devices

    # A blank entry (empty variable, trailing comma) used to crash on
    # int(""). CUDA only exposes the devices listed before the first invalid
    # entry, so stop parsing there instead of raising.
    visible_indices: List[int] = []
    for entry in visible_devices.split(","):
        entry = entry.strip()
        if not entry:
            break
        visible_indices.append(int(entry))

    # The position in the list is the logical index, so plain list indexing
    # replaces the former enumerate-built dict.
    return [
        visible_indices[i] for i in devices if 0 <= i < len(visible_indices)
    ]
@_nvml() @_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int], def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int, threshold_bytes: int,
timeout_s: float = 120) -> None: timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda # Use nvml instead of pytorch to reduce measurement error from torch cuda
# context. # context.
devices = get_physical_device_indices(devices)
start_time = time.time() start_time = time.time()
while True: while True:
output: Dict[int, str] = {} output: Dict[int, str] = {}
...@@ -441,6 +454,22 @@ def fork_new_process_for_each_test( ...@@ -441,6 +454,22 @@ def fork_new_process_for_each_test(
return wrapper return wrapper
def multi_gpu_test(*, num_gpus: int):
    """
    Build a decorator for tests that require at least ``num_gpus`` GPUs.

    The decorated test is wrapped with ``fork_new_process_for_each_test``,
    then marked ``skipif`` when fewer than ``num_gpus`` CUDA devices are
    counted, and finally tagged with the ``distributed_<num_gpus>_gpus``
    pytest mark.
    """
    marker_name = f"distributed_{num_gpus}_gpus"
    select_by_gpu_count = getattr(pytest.mark, marker_name)

    # The device count is evaluated once, when the decorator is created.
    too_few_gpus = cuda_device_count_stateless() < num_gpus
    skip_without_gpus = pytest.mark.skipif(
        too_few_gpus,
        reason=f"Need at least {num_gpus} GPUs to run the test.",
    )

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        # Apply innermost-first: fork wrapper, skip condition, selector mark.
        forked = fork_new_process_for_each_test(f)
        guarded = skip_without_gpus(forked)
        return select_by_gpu_count(guarded)

    return wrapper
async def completions_with_server_args( async def completions_with_server_args(
prompts: List[str], prompts: List[str],
model_name: str, model_name: str,
......
...@@ -251,16 +251,36 @@ def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, ...@@ -251,16 +251,36 @@ def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def advance_step(num_seqs: int, num_queries: int, block_size: int, def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, input_tokens: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor, sampled_token_ids: torch.Tensor,
slot_mapping: torch.Tensor, input_positions: torch.Tensor,
block_tables: torch.Tensor) -> None: seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner""" """Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step(num_seqs, num_queries, block_size, return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
input_tokens, sampled_token_ids, block_size, input_tokens,
input_positions, seq_lens, slot_mapping, sampled_token_ids,
block_tables) input_positions, seq_lens,
slot_mapping, block_tables)
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    """Call the ``_C.advance_step_flashinfer`` op with the given arguments.

    Thin wrapper: all arguments are forwarded positionally, in declaration
    order, to ``torch.ops._C.advance_step_flashinfer`` and its result is
    returned.
    """
    kernel = torch.ops._C.advance_step_flashinfer
    step_args = (
        num_seqs,
        num_queries,
        block_size,
        input_tokens,
        sampled_token_ids,
        input_positions,
        seq_lens,
        slot_mapping,
        block_tables,
        paged_kv_indices,
        paged_kv_indptr,
        paged_kv_last_page_len,
        block_table_bound,
    )
    return kernel(*step_args)
# trans_w16 # trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
......
...@@ -83,7 +83,9 @@ class AttentionBackend(ABC): ...@@ -83,7 +83,9 @@ class AttentionBackend(ABC):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
def advance_step(self, num_seqs: int, num_queries: int): def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -122,6 +122,40 @@ def _( ...@@ -122,6 +122,40 @@ def _(
return torch.empty_like(decode_query) return torch.empty_like(decode_query)
@torch.library.custom_op("vllm::reshape_and_cache_flash",
                         mutates_args=["kv_cache"])
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Write ``key``/``value`` into ``kv_cache`` via the ``_C_cache_ops`` op.

    Inductor cannot deal with inplace operations on views.
    See https://github.com/pytorch/pytorch/issues/131192
    and https://github.com/pytorch/pytorch/issues/130174
    Taking the combined ``kv_cache`` tensor and indexing it inside this
    custom op is a workaround that hides the view operation from the
    inductor.
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache_flash
    # kv_cache[0] is passed as the key cache and kv_cache[1] as the value
    # cache of the underlying op.
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    return cache_op(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
@reshape_and_cache_flash.register_fake  # type: ignore
def _(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Fake implementation of ``reshape_and_cache_flash``: does nothing."""
    return None
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -346,15 +380,15 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -346,15 +380,15 @@ class FlashAttentionMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step(num_seqs=num_seqs, ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries, num_queries=num_queries,
block_size=block_size, block_size=block_size,
input_tokens=model_input.input_tokens, input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions, input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor, seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping,
block_tables=self.block_tables) block_tables=self.block_tables)
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
...@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory profiling run.
ops.reshape_and_cache_flash( torch.ops.vllm.reshape_and_cache_flash(
key, key,
value, value,
key_cache, kv_cache,
value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale, k_scale,
...@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl):
assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
...@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl):
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None if (kv_cache is None or prefill_meta.block_tables is None
...@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
out = torch.ops.vllm.flash_attn_varlen_func( prefill_output = torch.ops.vllm.flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -701,42 +736,44 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -701,42 +736,44 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
output[: prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa q=query,
q=query, k=key_cache,
k=key_cache, v=value_cache,
v=value_cache, cu_seqlens_q=prefill_meta.query_start_loc,
cu_seqlens_q=prefill_meta.query_start_loc, max_seqlen_q=prefill_meta.max_query_len,
max_seqlen_q=prefill_meta.max_query_len, cu_seqlens_k=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc, max_seqlen_k=max_seq_len,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
).squeeze(1) )
# Reshape the output tensor. if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, ...@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc: Optional[torch.Tensor] = None query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None
# used for GPU in-place advance_step
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr: # An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8] # request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7] # request 2, page indices [1, 6, 7]
...@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
assert self.paged_kv_indices is not None assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
batch_size = self.query_start_loc.shape[0] - 1 batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0 assert batch_size >= 0
# We will use flash attention for profiling to # We will use flash attention for profiling to
...@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward() self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward( self.prefill_wrapper.begin_forward(
...@@ -335,14 +344,18 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -335,14 +344,18 @@ class FlashInferMetadata(AttentionMetadata):
self.num_qo_heads, self.num_kv_heads, self.head_dim, self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size) self.page_size)
else: else:
if not self.use_cuda_graph: assert self.paged_kv_indices is not None
assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None
assert self.paged_kv_last_page_len is not None self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device)
self.device) # handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None assert self.decode_wrapper is not None
self.decode_wrapper.end_forward() self.decode_wrapper.end_forward()
...@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
return self return self
def advance_step(
    self,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    sampled_token_ids: Optional[torch.Tensor],
    block_size: int,
    num_seqs: int,
    num_queries: int,
):
    """
    Advance this metadata by one decode step, updating it in place.

    The flattened ``sampled_token_ids`` are copied into the first
    ``num_queries`` entries of ``model_input.input_tokens``; the
    ``ops.advance_step_flashinfer`` op is then invoked with that same
    tensor passed as both ``input_tokens`` and ``sampled_token_ids``,
    together with this object's sequence-length, slot-mapping, block-table
    and paged-KV tensors.
    """
    assert num_seqs > 0
    assert num_queries > 0
    assert model_input.attn_metadata is not None
    assert sampled_token_ids is not None

    # When using cudagraph, num_seqs is padded to the next captured batch
    # size, but num_queries tracks the actual number of requests in the
    # batch. For --enforce-eager mode, num_seqs == num_queries.
    batch_is_padded = num_seqs != num_queries
    if batch_is_padded:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    next_tokens = sampled_token_ids.flatten()
    model_input.input_tokens[:num_queries] = next_tokens

    # Update GPU tensors
    step_kwargs = dict(
        num_seqs=num_seqs,
        num_queries=num_queries,
        block_size=block_size,
        input_tokens=model_input.input_tokens,
        sampled_token_ids=model_input.input_tokens,
        input_positions=model_input.input_positions,
        seq_lens=self.seq_lens_tensor,
        slot_mapping=self.slot_mapping,
        block_tables=self.block_tables,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        paged_kv_last_page_len=self.paged_kv_last_page_len,
        block_table_bound=self.block_table_bound,
    )
    ops.advance_step_flashinfer(**step_kwargs)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_indptr: List[int] = [0] self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request # paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = [] self.paged_kv_last_page_len: List[int] = []
self.total_blocks = 0
self.is_profile_run: bool = False self.is_profile_run: bool = False
def _add_seq_group( def _add_seq_group(
...@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# block_table_bound is 1 with 1 valid block. # block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16, # If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block. # block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \ block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \ if seq_len % self.block_size != 0 \
else seq_len // self.block_size else seq_len // self.block_size
...@@ -541,9 +597,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -541,9 +597,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to( block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True) device, non_blocking=True)
...@@ -583,6 +649,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -583,6 +649,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
out=query_start_loc[1:]) out=query_start_loc[1:])
if len(self.paged_kv_indptr) > 0: if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu", device="cpu",
dtype=torch.int) dtype=torch.int)
...@@ -591,10 +661,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -591,10 +661,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int) dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor( paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int) self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else: else:
paged_kv_indices_tensor = None paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None
if self.runner.kv_cache_dtype.startswith("fp8"): if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
...@@ -613,6 +688,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -613,6 +688,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor, paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor,
block_table_bound=block_table_bound_tensor,
seq_lens_tensor=seq_lens_tensor,
num_qo_heads=self.runner.model_config.get_num_attention_heads( num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config), self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads( num_kv_heads=self.runner.model_config.get_num_kv_heads(
......
...@@ -869,6 +869,13 @@ class ParallelConfig: ...@@ -869,6 +869,13 @@ class ParallelConfig:
f"distributed executor backend " f"distributed executor backend "
f"'{self.distributed_executor_backend}'.") f"'{self.distributed_executor_backend}'.")
if current_platform.is_tpu() and self.world_size > 1:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
if self.distributed_executor_backend != "ray":
raise ValueError(
"TPU backend only supports Ray for distributed inference.")
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
...@@ -876,7 +883,7 @@ class ParallelConfig: ...@@ -876,7 +883,7 @@ class ParallelConfig:
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend = "mp" backend = "mp"
ray_found = ray_utils.ray_is_available() ray_found = ray_utils.ray_is_available()
if (torch.cuda.is_available() if (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size): and cuda_device_count_stateless() < self.world_size):
if not ray_found: if not ray_found:
raise ValueError("Unable to load Ray which is " raise ValueError("Unable to load Ray which is "
......
...@@ -843,6 +843,13 @@ class EngineArgs: ...@@ -843,6 +843,13 @@ class EngineArgs:
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config() model_config = self.create_model_config()
if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len self.max_model_len, # neuron needs block_size = max_model_len
...@@ -874,7 +881,10 @@ class EngineArgs: ...@@ -874,7 +881,10 @@ class EngineArgs:
# If not explicitly set, enable chunked prefill by default for # If not explicitly set, enable chunked prefill by default for
# long context (> 32K) models. This is to avoid OOM errors in the # long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase. # initial memory profiling phase.
if use_long_context:
# Chunked prefill is currently disabled for multimodal models by
# default.
if use_long_context and not model_config.is_multimodal_model:
is_gpu = device_config.device_type == "cuda" is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window() use_sliding_window = (model_config.get_sliding_window()
is not None) is not None)
...@@ -1035,7 +1045,6 @@ class EngineArgs: ...@@ -1035,7 +1045,6 @@ class EngineArgs:
@dataclass @dataclass
class AsyncEngineArgs(EngineArgs): class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine.""" """Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False disable_log_requests: bool = False
@staticmethod @staticmethod
...@@ -1043,16 +1052,6 @@ class AsyncEngineArgs(EngineArgs): ...@@ -1043,16 +1052,6 @@ class AsyncEngineArgs(EngineArgs):
async_args_only: bool = False) -> FlexibleArgumentParser: async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only: if not async_args_only:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
help='Use Ray to start the LLM engine in a '
'separate process as the server process.'
'(DEPRECATED. This argument is deprecated '
'and will be removed in a future update. '
'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
'use it. See '
'https://github.com/vllm-project/vllm/issues/7045.'
')')
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
......
...@@ -4,22 +4,18 @@ from functools import partial ...@@ -4,22 +4,18 @@ from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union) Mapping, Optional, Set, Tuple, Type, Union)
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
PromptComponents, SchedulerOutputState)
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, from vllm.executor.ray_utils import initialize_ray_cluster
SingletonPromptInputs) from vllm.inputs import PromptInputs
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -30,7 +26,6 @@ from vllm.sampling_params import SamplingParams ...@@ -30,7 +26,6 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
...@@ -404,139 +399,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -404,139 +399,6 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
async def _tokenize_prompt_async(
    self,
    prompt: str,
    request_id: str,
    lora_request: Optional[LoRARequest],
) -> List[int]:
    """Async version of :meth:`_tokenize_prompt`.

    Fetches the tokenizer group and awaits its ``encode_async`` on
    ``prompt``, returning the result.
    """
    tokenizer_group = self.get_tokenizer_group(
        missing_msg="prompts must be None if skip_tokenizer_init is True")

    token_ids = await tokenizer_group.encode_async(
        request_id=request_id,
        prompt=prompt,
        lora_request=lora_request,
    )
    return token_ids
async def _extract_prompt_components_async(
    self,
    inputs: SingletonPromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
    """Async version of :meth:`_extract_prompt_components`.

    Returns a ``(prompt, prompt_token_ids, multi_modal_data)`` tuple:

    * ``str`` input: the string is tokenized; ``multi_modal_data`` is
      ``None``.
    * ``dict`` input with ``"prompt_token_ids"``: those ids are used as-is
      and ``prompt`` is ``None``.
    * other ``dict`` input: ``inputs["prompt"]`` is tokenized.

    For dict inputs ``multi_modal_data`` is ``inputs.get("multi_modal_data")``.
    Any other input type is passed to ``assert_never``.
    """
    if isinstance(inputs, str):
        token_ids = await self._tokenize_prompt_async(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return inputs, token_ids, None

    if isinstance(inputs, dict):
        if "prompt_token_ids" not in inputs:
            # NOTE: This extra assignment is required to pass mypy
            prompt = parsed_prompt = inputs["prompt"]
            token_ids = await self._tokenize_prompt_async(
                parsed_prompt,
                request_id=request_id,
                lora_request=lora_request,
            )
        else:
            prompt = None
            token_ids = inputs["prompt_token_ids"]

        return prompt, token_ids, inputs.get("multi_modal_data")

    assert_never(inputs)
async def _process_encoder_decoder_prompt_async(
    self,
    inputs: PromptInputs,
    request_id: str,
) -> EncoderDecoderLLMInputs:
    """Async version of :meth:`_process_encoder_decoder_prompt`.

    For an explicit encoder/decoder prompt that supplies both sub-prompts,
    the two are extracted together through ``asyncio.gather``. Otherwise
    only the encoder prompt is extracted and the decoder components are
    left as ``(None, None, None)``.

    Args:
        inputs: A singleton prompt or an explicit encoder/decoder prompt.
        request_id: Identifier of the request.

    Returns:
        The result of :meth:`_build_enc_dec_llm_inputs` on the extracted
        encoder and decoder components.
    """
    encoder_comps: PromptComponents
    decoder_comps: DecoderPromptComponents

    if is_explicit_encoder_decoder_prompt(inputs):
        # Create the encoder coroutine first; it is awaited below, either
        # alone or together with the decoder coroutine.
        pending_encoder = self._extract_prompt_components_async(
            inputs["encoder_prompt"],
            request_id=request_id,
        )

        decoder_input = inputs["decoder_prompt"]
        if decoder_input is None:
            encoder_comps = await pending_encoder
            decoder_comps = None, None, None
        else:
            pending_decoder = self._extract_prompt_components_async(
                decoder_input,
                request_id=request_id,
            )
            encoder_comps, decoder_comps = await asyncio.gather(
                pending_encoder, pending_decoder)
    else:
        # Singleton prompt: it is the encoder prompt.
        encoder_comps = await self._extract_prompt_components_async(
            inputs,
            request_id=request_id,
        )
        decoder_comps = None, None, None

    return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
    self,
    inputs: SingletonPromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
    """Async version of :meth:`_process_decoder_only_prompt`.

    Args:
        inputs: The singleton prompt to process.
        request_id: Identifier of the request.
        lora_request: Optional LoRA request, used during extraction.
        prompt_adapter_request: Optional prompt-adapter request, used when
            building the final inputs.

    Returns:
        The result of :meth:`_build_decoder_only_llm_inputs` on the
        extracted prompt components.
    """
    # Step 1: extract (prompt, token ids, multi-modal data).
    components = await self._extract_prompt_components_async(
        inputs,
        request_id=request_id,
        lora_request=lora_request,
    )

    # Step 2: assemble the decoder-only inputs.
    llm_inputs = self._build_decoder_only_llm_inputs(
        components,
        prompt_adapter_request=prompt_adapter_request,
    )
    return llm_inputs
async def process_model_inputs_async(
    self,
    inputs: PromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
    """Async version of :meth:`process_model_inputs`.

    Args:
        inputs: The prompt supplied with the request.
        request_id: Identifier of the request.
        lora_request: Optional LoRA request (decoder-only path only).
        prompt_adapter_request: Optional prompt-adapter request
            (decoder-only path only).

    Returns:
        The output of ``self.input_processor`` applied to the model inputs.

    Raises:
        ValueError: If an explicit encoder/decoder prompt is given to a
            decoder-only model.
    """
    if self.is_encoder_decoder_model():
        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder
        enc_dec_inputs = await self._process_encoder_decoder_prompt_async(
            inputs,
            request_id=request_id,
        )
        return self.input_processor(enc_dec_inputs)

    if is_explicit_encoder_decoder_prompt(inputs):
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")

    # Decoder-only operation
    decoder_only_inputs = await self._process_decoder_only_prompt_async(
        inputs,
        request_id=request_id,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
    return self.input_processor(decoder_only_inputs)
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
...@@ -554,12 +416,13 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -554,12 +416,13 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async( preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs, inputs,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -590,9 +453,6 @@ class AsyncLLMEngine: ...@@ -590,9 +453,6 @@ class AsyncLLMEngine:
worker_use_ray: Whether to use Ray for model workers. Required for worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as distributed execution. Should be the same as
`parallel_config.worker_use_ray`. `parallel_config.worker_use_ray`.
engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
...@@ -604,41 +464,23 @@ class AsyncLLMEngine: ...@@ -604,41 +464,23 @@ class AsyncLLMEngine:
def __init__(self, def __init__(self,
worker_use_ray: bool, worker_use_ray: bool,
engine_use_ray: bool,
*args, *args,
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs) self.engine = self._engine_class(*args, **kwargs)
# This ensures quick processing of request outputs # This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed, # so the append to asyncio queues is not delayed,
# especially for multi-step. # especially for multi-step.
# #
# TODO: Currently, disabled for engine_use_ray, ask self.use_process_request_outputs_callback = True
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback: if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \ self.engine.process_request_outputs_callback = \
self.process_request_outputs self.process_request_outputs
if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
"be removed in a future update. "
"See https://github.com/vllm-project/vllm/issues/7045.")
if envs.VLLM_ALLOW_ENGINE_USE_RAY:
print_warning_once(
"VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
else:
raise ValueError("`--engine-use-ray` is deprecated. "
"Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
"force use it")
self.background_loop: Optional[asyncio.Future] = None self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded # We need to keep a reference to unshielded
# task as well to prevent it from being garbage # task as well to prevent it from being garbage
...@@ -725,16 +567,11 @@ class AsyncLLMEngine: ...@@ -725,16 +567,11 @@ class AsyncLLMEngine:
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
executor_class = cls._get_executor_cls(engine_config) executor_class = cls._get_executor_cls(engine_config)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
executor_class.uses_ray, executor_class.uses_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(), **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
...@@ -777,10 +614,6 @@ class AsyncLLMEngine: ...@@ -777,10 +614,6 @@ class AsyncLLMEngine:
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer: ) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)
return await (self.engine.get_tokenizer_group(). return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request)) get_lora_tokenizer_async(lora_request))
...@@ -814,26 +647,6 @@ class AsyncLLMEngine: ...@@ -814,26 +647,6 @@ class AsyncLLMEngine:
self._background_loop_unshielded = None self._background_loop_unshielded = None
self.background_loop = None self.background_loop = None
def _init_engine(self, *args,
                 **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
    """Instantiate the underlying engine, optionally as a Ray actor.

    When ``self.engine_use_ray`` is false, ``self._engine_class`` is called
    directly. Otherwise the class is wrapped with ``ray.remote`` and
    constructed through ``.remote(...)``.

    Args:
        *args: Positional arguments forwarded to the engine constructor.
        **kwargs: Keyword arguments forwarded to the engine constructor.
            ``"cache_config"`` and ``"parallel_config"`` are read from here
            when the engine uses Ray but the workers do not.

    Returns:
        Whatever the chosen constructor returns.
    """
    # In-process engine: no Ray involved.
    if not self.engine_use_ray:
        return self._engine_class(*args, **kwargs)

    # Workers already use Ray: the engine actor is requested with no CPUs.
    if self.worker_use_ray:
        remote_ctor = ray.remote(num_cpus=0)(self._engine_class).remote
        return remote_ctor(*args, **kwargs)

    # FIXME(woosuk): This is a bit hacky. Be careful when changing the
    # order of the arguments.
    cache_config = kwargs["cache_config"]
    parallel_config = kwargs["parallel_config"]
    single_device = (parallel_config.tensor_parallel_size == 1
                     and parallel_config.pipeline_parallel_size == 1)
    # Without tensor/pipeline parallelism the GPU share requested from Ray
    # is the configured memory utilization; otherwise it is 1.
    num_gpus = cache_config.gpu_memory_utilization if single_device else 1
    remote_ctor = ray.remote(num_gpus=num_gpus)(self._engine_class).remote
    return remote_ctor(*args, **kwargs)
async def engine_step(self, virtual_engine: int) -> bool: async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests. """Kick the engine to process the waiting requests.
...@@ -844,13 +657,8 @@ class AsyncLLMEngine: ...@@ -844,13 +657,8 @@ class AsyncLLMEngine:
for new_request in new_requests: for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue. # Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
try: try:
if self.engine_use_ray: await self.engine.add_request_async(**new_request)
await self.engine.add_request.remote( # type: ignore
**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e: except ValueError as e:
# TODO: use a vLLM specific error for failed validation # TODO: use a vLLM specific error for failed validation
self._request_tracker.process_exception( self._request_tracker.process_exception(
...@@ -862,10 +670,7 @@ class AsyncLLMEngine: ...@@ -862,10 +670,7 @@ class AsyncLLMEngine:
if aborted_requests: if aborted_requests:
await self._engine_abort(aborted_requests) await self._engine_abort(aborted_requests)
if self.engine_use_ray: request_outputs = await self.engine.step_async(virtual_engine)
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
# If used as a callback, then already invoked inside # If used as a callback, then already invoked inside
...@@ -891,16 +696,10 @@ class AsyncLLMEngine: ...@@ -891,16 +696,10 @@ class AsyncLLMEngine:
return all_finished return all_finished
async def _engine_abort(self, request_ids: Iterable[str]): async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray: self.engine.abort_request(request_ids)
await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
if self.engine_use_ray: pipeline_parallel_size = \
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size
while True: while True:
...@@ -912,12 +711,7 @@ class AsyncLLMEngine: ...@@ -912,12 +711,7 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that # timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages, # they can process any other queued control plane messages,
# such as add/remove lora adapters. # such as add/remove lora adapters.
if self.engine_use_ray: await self.engine.stop_remote_worker_execution_loop_async()
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!") logger.debug("Got new requests!")
requests_in_progress = [ requests_in_progress = [
...@@ -938,17 +732,9 @@ class AsyncLLMEngine: ...@@ -938,17 +732,9 @@ class AsyncLLMEngine:
for task in done: for task in done:
result = task.result() result = task.result()
virtual_engine = requests_in_progress.index(task) virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray: has_unfinished_requests = (
has_unfinished_requests = ( self.engine.has_unfinished_requests_for_virtual_engine(
await (self.engine. virtual_engine))
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests: if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = ( requests_in_progress[virtual_engine] = (
asyncio.create_task( asyncio.create_task(
...@@ -1190,52 +976,29 @@ class AsyncLLMEngine: ...@@ -1190,52 +976,29 @@ class AsyncLLMEngine:
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
if self.engine_use_ray: return self.engine.get_model_config()
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig: async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine.""" """Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray: return self.engine.get_parallel_config()
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine.""" """Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray: return self.engine.get_decoding_config()
return await self.engine.get_decoding_config.remote( # type: ignore
)
else:
return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig: async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine.""" """Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray: return self.engine.get_scheduler_config()
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig: async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine.""" """Get the lora configuration of the vLLM engine."""
if self.engine_use_ray: return self.engine.get_lora_config()
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None: model_output: Optional[List[SamplerOutput]] = None) -> None:
if self.engine_use_ray: self.engine.do_log_stats()
await self.engine.do_log_stats.remote( # type: ignore
scheduler_outputs, model_output)
else:
self.engine.do_log_stats()
async def check_health(self) -> None: async def check_health(self) -> None:
"""Raises an error if engine is unhealthy.""" """Raises an error if engine is unhealthy."""
...@@ -1244,40 +1007,30 @@ class AsyncLLMEngine: ...@@ -1244,40 +1007,30 @@ class AsyncLLMEngine:
if self.is_stopped: if self.is_stopped:
raise AsyncEngineDeadError("Background loop is stopped.") raise AsyncEngineDeadError("Background loop is stopped.")
if self.engine_use_ray: await self.engine.check_health_async()
try:
await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else:
await self.engine.check_health_async()
logger.debug("Health check took %fs", time.perf_counter() - t) logger.debug("Health check took %fs", time.perf_counter() - t)
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
if self.engine_use_ray: return self.engine.is_tracing_enabled()
return await self.engine.is_tracing_enabled.remote( # type: ignore
)
else:
return self.engine.is_tracing_enabled()
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if self.engine_use_ray: self.engine.add_logger(logger_name=logger_name, logger=logger)
ray.get(
self.engine.add_logger.remote( # type: ignore
logger_name=logger_name, logger=logger))
else:
self.engine.add_logger(logger_name=logger_name, logger=logger)
def remove_logger(self, logger_name: str) -> None: def remove_logger(self, logger_name: str) -> None:
if self.engine_use_ray: self.engine.remove_logger(logger_name=logger_name)
ray.get(
self.engine.remove_logger.remote( # type: ignore
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None: async def start_profile(self) -> None:
self.engine.model_executor._run_workers("start_profile") # using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
self.engine.model_executor._run_workers("stop_profile") # using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
...@@ -3,13 +3,13 @@ import time ...@@ -3,13 +3,13 @@ import time
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Mapping, NamedTuple, Optional) Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union from typing import Set, Type, Union
import torch import torch
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
...@@ -26,20 +26,19 @@ from vllm.engine.output_processor.interfaces import ( ...@@ -26,20 +26,19 @@ from vllm.engine.output_processor.interfaces import (
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs, InputRegistry, LLMInputs, PromptInputs)
SingletonPromptInputs) from vllm.inputs.preprocess import InputPreprocessor
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata, Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
...@@ -75,11 +74,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: ...@@ -75,11 +74,6 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
PromptComponents = Tuple[Optional[str], List[int],
Optional[MultiModalDataDict]]
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
Optional[MultiModalDataDict]]
@dataclass @dataclass
class SchedulerOutputState: class SchedulerOutputState:
...@@ -225,9 +219,6 @@ class LLMEngine: ...@@ -225,9 +219,6 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only final requests outputs may be required.
# If this set to true, then no intermediate outputs will be returned.
step_return_finished_only: bool = False,
) -> None: ) -> None:
logger.info( logger.info(
"Initializing an LLM engine (v%s) with config: " "Initializing an LLM engine (v%s) with config: "
...@@ -295,7 +286,6 @@ class LLMEngine: ...@@ -295,7 +286,6 @@ class LLMEngine:
self.observability_config = observability_config or ObservabilityConfig( self.observability_config = observability_config or ObservabilityConfig(
) )
self.log_stats = log_stats self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
...@@ -317,6 +307,9 @@ class LLMEngine: ...@@ -317,6 +307,9 @@ class LLMEngine:
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
model_config) model_config)
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer)
self.input_registry = input_registry self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor( self.input_processor = input_registry.create_input_processor(
model_config) model_config)
...@@ -583,19 +576,15 @@ class LLMEngine: ...@@ -583,19 +576,15 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group( def get_tokenizer_group(
self, self,
group_type: Type[_G] = BaseTokenizerGroup, group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G: ) -> _G:
tokenizer_group = self.tokenizer tokenizer_group = self.tokenizer
if tokenizer_group is None: if tokenizer_group is None:
raise ValueError(missing_msg) raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type): if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. " raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but " f"Expected type: {group_type}, but "
...@@ -627,52 +616,6 @@ class LLMEngine: ...@@ -627,52 +616,6 @@ class LLMEngine:
self.prompt_adapter_config.verify_with_model_config( self.prompt_adapter_config.verify_with_model_config(
self.model_config) self.model_config)
def _get_bos_token_id(
    self,
    lora_request: Optional[LoRARequest] = None,
) -> Optional[int]:
    """Return the BOS token id of the tokenizer for ``lora_request``.

    Args:
        lora_request: Optional LoRA request used to select the tokenizer.

    Returns:
        The ``bos_token_id`` attribute of the selected tokenizer, or
        ``None`` (after logging a warning) when ``self.tokenizer`` is
        ``None``.
    """
    if self.tokenizer is not None:
        lora_tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        return lora_tokenizer.bos_token_id

    logger.warning("Using None for BOS token id because tokenizer "
                   "is not initialized")
    return None
def _get_eos_token_id(
    self,
    lora_request: Optional[LoRARequest] = None,
) -> Optional[int]:
    """Return the EOS token id of the tokenizer for ``lora_request``.

    Args:
        lora_request: Optional LoRA request used to select the tokenizer.

    Returns:
        The ``eos_token_id`` attribute of the selected tokenizer, or
        ``None`` (after logging a warning) when ``self.tokenizer`` is
        ``None``.
    """
    if self.tokenizer is not None:
        lora_tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
        return lora_tokenizer.eos_token_id

    logger.warning("Using None for EOS token id because tokenizer "
                   "is not initialized")
    return None
def _get_decoder_start_token_id(self) -> Optional[int]:
    """Obtain the decoder start token id of an encoder/decoder model.

    Returns:
        ``None`` (with a warning) when the model is not an encoder/decoder
        model or when the model config or its ``hf_config`` is missing.
        Otherwise the ``decoder_start_token_id`` attribute of ``hf_config``
        or, when that attribute is absent or ``None``, the BOS token id
        from :meth:`_get_bos_token_id` (also with a warning).
    """
    if not self.is_encoder_decoder_model():
        logger.warning("Using None for decoder start token id because "
                       "this is not an encoder/decoder model.")
        return None

    model_config = self.model_config
    if model_config is None or model_config.hf_config is None:
        logger.warning("Using None for decoder start token id because "
                       "model config is not available.")
        return None

    start_token_id = getattr(model_config.hf_config,
                             'decoder_start_token_id', None)
    if start_token_id is not None:
        return start_token_id

    logger.warning("Falling back on <BOS> for decoder start token id "
                   "because decoder start token id is not available.")
    return self._get_bos_token_id()
def _add_processed_request( def _add_processed_request(
self, self,
request_id: str, request_id: str,
...@@ -687,7 +630,7 @@ class LLMEngine: ...@@ -687,7 +630,7 @@ class LLMEngine:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request) lora_request, prompt_adapter_request)
...@@ -737,334 +680,6 @@ class LLMEngine: ...@@ -737,334 +680,6 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
_LLMInputComponentsType = Tuple[str, List[int]]
def _prepare_decoder_input_ids_for_generation(
    self,
    decoder_input_ids: Optional[List[int]],
) -> List[int]:
    """Prepare ``decoder_input_ids`` for encoder-decoder generation.

    Based on
    https://github.com/huggingface/transformers/blob/
    4037a2b5b1278736e566aec12e169100275545ea/
    src/transformers/generation/utils.py
    specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

    Args:
        decoder_input_ids: Input token ids to preprocess, or ``None`` to
            use the default decoder prompt.

    Returns:
        A token list whose first element is the decoder start token id.
    """
    start_token_id = self._get_decoder_start_token_id()
    assert start_token_id is not None

    token_ids = decoder_input_ids
    if token_ids is None:
        # No decoder prompt supplied: fall back on the default one.
        token_ids = self._get_default_enc_dec_decoder_prompt()

    # Prepend the start token unless the sequence already begins with it.
    missing_start = (len(token_ids) == 0 or token_ids[0] != start_token_id)
    if missing_start:
        token_ids = [start_token_id] + token_ids

    return token_ids
def _tokenize_prompt(
    self,
    prompt: str,
    request_id: str,
    lora_request: Optional[LoRARequest],
) -> List[int]:
    """Wrapper around application of the model's tokenizer.

    Args:
        prompt: Text to tokenize.
        request_id: Identifier of the request, forwarded to the tokenizer.
        lora_request: Optional LoRA request, forwarded to the tokenizer.

    Returns:
        The prompt token ids produced by the tokenizer group's ``encode``.
    """
    # The custom message is handed to get_tokenizer_group for the case
    # where no tokenizer is available.
    group = self.get_tokenizer_group(
        missing_msg="prompts must be None if skip_tokenizer_init is True")

    token_ids = group.encode(
        request_id=request_id,
        prompt=prompt,
        lora_request=lora_request,
    )
    return token_ids
def _extract_prompt_components(
    self,
    inputs: SingletonPromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
    """Extract the components of any single encoder or decoder prompt.

    Args:
        inputs: Single encoder or decoder input prompt: either a plain
            string or a dict. A dict may carry ``"prompt_token_ids"``
            (used as-is) or ``"prompt"`` (tokenized), plus optional
            ``"multi_modal_data"``.
        request_id: Identifier of the request, forwarded to the tokenizer.
        lora_request: This is only valid for decoder prompts.

    Returns:
        A ``(prompt, prompt_token_ids, multi_modal_data)`` tuple. ``prompt``
        is ``None`` when token ids were supplied directly, and
        ``multi_modal_data`` is ``None`` for string prompts.
    """
    # Plain text: tokenize it; there is no multi-modal payload.
    if isinstance(inputs, str):
        text_ids = self._tokenize_prompt(
            inputs,
            request_id=request_id,
            lora_request=lora_request,
        )
        return inputs, text_ids, None

    # Anything that is neither str nor dict is outside the declared union.
    if not isinstance(inputs, dict):
        assert_never(inputs)

    if "prompt_token_ids" in inputs:
        # Pre-tokenized prompt: no text is recorded.
        text = None
        token_ids = inputs["prompt_token_ids"]
    else:
        # NOTE: The separate ``raw_text`` binding is kept for mypy.
        text = raw_text = inputs["prompt"]
        token_ids = self._tokenize_prompt(
            raw_text,
            request_id=request_id,
            lora_request=lora_request,
        )

    return text, token_ids, inputs.get("multi_modal_data")
def _apply_prompt_adapter(
    self,
    prompt_token_ids: List[int],
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> List[int]:
    """Prepend placeholder tokens for a prompt adapter, if one is given.

    Args:
        prompt_token_ids: Token ids of the prompt.
        prompt_adapter_request: Optional prompt-adapter request. When it is
            truthy, its ``prompt_adapter_num_virtual_tokens`` determines how
            many ``0`` tokens are prepended.

    Returns:
        ``prompt_token_ids`` unchanged when no adapter request is given,
        otherwise a new list with the zero tokens in front.
    """
    if not prompt_adapter_request:
        return prompt_token_ids

    num_virtual = prompt_adapter_request.prompt_adapter_num_virtual_tokens
    virtual_tokens = [0] * num_virtual
    return virtual_tokens + prompt_token_ids
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
    """Build the default decoder prompt for encoder/decoder models.

    Used when the user specifies only the encoder prompt.

    Encoder/decoder models utilize the decoder prompt in different ways;
    as new models are added, this helper is intended to be extended to
    produce differing default decoder prompts depending on the model
    variety.

    Absent a special case, the default mirrors the HuggingFace (HF)
    GenerationMixin behavior for a ``None`` decoder prompt, which employs a
    logit processor setting to force the first decoded token to be <BOS>.
    That is approximated here by making the default decoder prompt <BOS>.

    Returns:
        A single-element list holding the BOS token id.
    """
    bos = self._get_bos_token_id()
    # A BOS token id is required to form the default prompt.
    assert bos is not None
    default_prompt = [bos]
    return default_prompt
def _build_enc_dec_llm_inputs(
    self,
    encoder_comps: PromptComponents,
    decoder_comps: DecoderPromptComponents,
) -> EncoderDecoderLLMInputs:
    """Combine encoder and decoder prompt components into engine inputs.

    Args:
        encoder_comps: ``(prompt, prompt_token_ids, multi_modal_data)`` of
            the encoder prompt.
        decoder_comps: The same triple for the decoder prompt; its entries
            may be ``None``.

    Returns:
        An :class:`EncoderDecoderLLMInputs` whose decoder token ids have
        been passed through
        :meth:`_prepare_decoder_input_ids_for_generation`.

    Raises:
        ValueError: If either side carries multi-modal data.
    """
    enc_text, enc_ids, enc_mm_data = encoder_comps
    dec_text, dec_ids, dec_mm_data = decoder_comps

    no_mm_data = enc_mm_data is None and dec_mm_data is None
    if not no_mm_data:
        raise ValueError("Multi-modal encoder-decoder models are "
                         "not supported yet")

    prepared_dec_ids = self._prepare_decoder_input_ids_for_generation(
        dec_ids)

    return EncoderDecoderLLMInputs(
        prompt_token_ids=prepared_dec_ids,
        prompt=dec_text,
        encoder_prompt_token_ids=enc_ids,
        encoder_prompt=enc_text,
    )
def _process_encoder_decoder_prompt(
    self,
    inputs: PromptInputs,
    request_id: str,
) -> EncoderDecoderLLMInputs:
    '''
    For encoder/decoder models only:
    turn an input prompt into an
    :class:`EncoderDecoderLLMInputs` instance.

    Two kinds of input prompt are accepted:

    * Singleton prompt: carries only the encoder prompt.
      The encoder components are extracted from it and the
      decoder components are left empty, so that a default
      decoder prompt is inferred downstream.
    * Explicit encoder/decoder prompt: carries both an
      "encoder_prompt" and a "decoder_prompt" entry. The encoder
      components are extracted from the former; the decoder
      components are extracted from the latter unless it is None,
      in which case they are left empty as well.

    Each sub-prompt of an explicit encoder/decoder prompt may be
    of any singleton type, which is why component extraction is
    delegated to a helper method.

    Arguments:

    * inputs: an input prompt
    * request_id

    Returns:

    * :class:`EncoderDecoderLLMInputs` instance
    '''

    encoder_comps: PromptComponents
    decoder_comps: DecoderPromptComponents

    # Singleton prompt: the whole input is the encoder prompt.
    if not is_explicit_encoder_decoder_prompt(inputs):
        encoder_comps = self._extract_prompt_components(
            inputs,
            request_id=request_id,
        )
        decoder_comps = None, None, None
        return self._build_enc_dec_llm_inputs(encoder_comps,
                                              decoder_comps)

    # Explicit prompt: the encoder side is always extracted first.
    encoder_comps = self._extract_prompt_components(
        inputs["encoder_prompt"],
        request_id=request_id,
    )

    decoder_input = inputs["decoder_prompt"]
    if decoder_input is not None:
        decoder_comps = self._extract_prompt_components(
            decoder_input,
            request_id=request_id,
        )
    else:
        decoder_comps = None, None, None

    return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
def _build_decoder_only_llm_inputs(
    self,
    prompt_comps: PromptComponents,
    prompt_adapter_request: Optional[PromptAdapterRequest],
) -> LLMInputs:
    '''
    Assemble an :class:`LLMInputs` instance from extracted
    prompt components.

    The components are unpacked as a 3-tuple of
    (prompt, prompt token ids, multi-modal data). The token ids
    are run through the prompt adapter helper before being stored.

    Arguments:

    * prompt_comps: components of the prompt
    * prompt_adapter_request

    Returns:

    * :class:`LLMInputs` instance
    '''

    prompt_text, raw_token_ids, mm_data = prompt_comps

    adapted_token_ids = self._apply_prompt_adapter(
        raw_token_ids,
        prompt_adapter_request=prompt_adapter_request,
    )

    return LLMInputs(
        prompt_token_ids=adapted_token_ids,
        prompt=prompt_text,
        multi_modal_data=mm_data,
    )
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
Arguments:
* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
Returns:
* :class:`LLMInputs` instance
'''
prompt_comps = self._extract_prompt_components(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
def process_model_inputs(
    self,
    inputs: PromptInputs,
    request_id: str,
    lora_request: Optional[LoRARequest] = None,
    prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
    '''
    Convert an input prompt into model inputs and run them through
    the engine's input processor.

    Encoder/decoder models take a dedicated path that maps the
    prompt onto encoder and decoder inputs; on that path the LoRA
    and prompt adapter requests are not forwarded. Decoder-only
    models reject explicit encoder/decoder prompts and forward
    both requests to the decoder-only path.

    Arguments:

    * inputs: input prompt
    * request_id
    * lora_request
    * prompt_adapter_request

    Returns:

    * the result of ``self.input_processor`` applied to the
      model inputs

    Raises:

    * ValueError: if an explicit encoder/decoder prompt is passed
      to a decoder-only model
    '''

    if self.is_encoder_decoder_model():
        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder
        enc_dec_inputs = self._process_encoder_decoder_prompt(
            inputs,
            request_id=request_id,
        )
        return self.input_processor(enc_dec_inputs)

    if is_explicit_encoder_decoder_prompt(inputs):
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")

    # Decoder-only operation
    decoder_only_inputs = self._process_decoder_only_prompt(
        inputs,
        request_id=request_id,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
    return self.input_processor(decoder_only_inputs)
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
...@@ -1123,12 +738,13 @@ class LLMEngine: ...@@ -1123,12 +738,13 @@ class LLMEngine:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
processed_inputs = self.process_model_inputs( preprocessed_inputs = self.input_preprocessor.preprocess(
inputs, inputs,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -1281,7 +897,7 @@ class LLMEngine: ...@@ -1281,7 +897,7 @@ class LLMEngine:
ctx: The virtual engine context to work on ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed request_id: If provided, then only this request is going to be processed
""" """
now = time.time() now = time.time()
...@@ -1386,7 +1002,8 @@ class LLMEngine: ...@@ -1386,7 +1002,8 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) if request_output:
ctx.request_outputs.append(request_output)
# When we process a single request, we skip it for the next time, # When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output) # and invoke the request output callback (if there was final output)
...@@ -1423,14 +1040,19 @@ class LLMEngine: ...@@ -1423,14 +1040,19 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now) seq_group.maybe_set_first_token_time(now)
if (seq_group.is_finished() request_output = RequestOutputFactory.create(seq_group)
if self.step_return_finished_only else True): if request_output:
request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) ctx.request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups: for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue
request_output = RequestOutputFactory.create(seq_group) request_output = RequestOutputFactory.create(seq_group)
ctx.request_outputs.append(request_output) if request_output:
ctx.request_outputs.append(request_output)
# Immediately process request outputs here (if callback is given) # Immediately process request outputs here (if callback is given)
if (ctx.request_outputs if (ctx.request_outputs
...@@ -1443,7 +1065,8 @@ class LLMEngine: ...@@ -1443,7 +1065,8 @@ class LLMEngine:
# LLMEngine/AsyncLLMEngine directly # LLMEngine/AsyncLLMEngine directly
if is_async: if is_async:
# Log stats. # Log stats.
self.do_log_stats(scheduler_outputs, outputs, finished_before) self.do_log_stats(scheduler_outputs, outputs, finished_before,
skip)
# Tracing # Tracing
self.do_tracing(scheduler_outputs) self.do_tracing(scheduler_outputs)
...@@ -1750,18 +1373,20 @@ class LLMEngine: ...@@ -1750,18 +1373,20 @@ class LLMEngine:
def do_log_stats(self, def do_log_stats(self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None, model_output: Optional[List[SamplerOutput]] = None,
finished_before: Optional[List[int]] = None) -> None: finished_before: Optional[List[int]] = None,
skip: Optional[List[int]] = None) -> None:
"""Forced log when no requests active.""" """Forced log when no requests active."""
if self.log_stats: if self.log_stats:
stats = self._get_stats(scheduler_outputs, model_output, stats = self._get_stats(scheduler_outputs, model_output,
finished_before) finished_before, skip)
for logger in self.stat_loggers.values(): for logger in self.stat_loggers.values():
logger.log(stats) logger.log(stats)
def _get_stats(self, def _get_stats(self,
scheduler_outputs: Optional[SchedulerOutputs], scheduler_outputs: Optional[SchedulerOutputs],
model_output: Optional[List[SamplerOutput]] = None, model_output: Optional[List[SamplerOutput]] = None,
finished_before: Optional[List[int]] = None) -> Stats: finished_before: Optional[List[int]] = None,
skip: Optional[List[int]] = None) -> Stats:
"""Get Stats to be Logged to Prometheus. """Get Stats to be Logged to Prometheus.
Args: Args:
...@@ -1769,6 +1394,10 @@ class LLMEngine: ...@@ -1769,6 +1394,10 @@ class LLMEngine:
the scheduled batch, the scheduled batch,
model_output: Optional, used to emit speculative decoding metrics model_output: Optional, used to emit speculative decoding metrics
which are created by the workers. which are created by the workers.
finished_before: Optional, indices of sequences that were finished
before. These sequences will be ignored.
skip: Optional, indices of sequences that were preempted. These
sequences will be ignored.
""" """
now = time.time() now = time.time()
...@@ -1843,6 +1472,11 @@ class LLMEngine: ...@@ -1843,6 +1472,11 @@ class LLMEngine:
actual_num_batched_tokens -= 1 actual_num_batched_tokens -= 1
continue continue
# Currently, skip == preempted sequences, so we need to skip
# their log stats
if skip and idx in skip:
continue
group_was_prefill = idx < scheduler_outputs.num_prefill_groups group_was_prefill = idx < scheduler_outputs.num_prefill_groups
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
...@@ -1972,10 +1606,20 @@ class LLMEngine: ...@@ -1972,10 +1606,20 @@ class LLMEngine:
self.model_executor.check_health() self.model_executor.check_health()
def start_profile(self) -> None: def start_profile(self) -> None:
self.model_executor.start_profile() # using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor:
self.model_executor.start_profile()
else:
self.model_executor._run_workers("start_profile")
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.model_executor.stop_profile() # using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor:
self.model_executor.stop_profile()
else:
self.model_executor._run_workers("stop_profile")
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:
return self.tracer is not None return self.tracer is not None
...@@ -2049,7 +1693,7 @@ class LLMEngine: ...@@ -2049,7 +1693,7 @@ class LLMEngine:
metrics.model_execute_time) metrics.model_execute_time)
def is_encoder_decoder_model(self): def is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model return self.input_preprocessor.is_encoder_decoder_model()
def is_embedding_model(self): def is_embedding_model(self):
return self.model_config.is_embedding_model return self.model_config.is_embedding_model
......
...@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions ...@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -642,14 +642,12 @@ class LLM: ...@@ -642,14 +642,12 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request " raise ValueError("The lengths of prompts and lora_request "
"must be the same.") "must be the same.")
if isinstance(params, list): for sp in params if isinstance(params, list) else (params, ):
params = [ if isinstance(sp, SamplingParams):
self._add_guided_processor(param, guided_options) self._add_guided_processor(sp, guided_options)
if isinstance(param, SamplingParams) else param
for param in params # We only care about the final output
] sp.output_kind = RequestOutputKind.FINAL_ONLY
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, request_inputs in enumerate(inputs):
...@@ -709,9 +707,6 @@ class LLM: ...@@ -709,9 +707,6 @@ class LLM:
f"output: {0:.2f} toks/s"), f"output: {0:.2f} toks/s"),
) )
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0 total_in_toks = 0
...@@ -724,6 +719,7 @@ class LLM: ...@@ -724,6 +719,7 @@ class LLM:
if use_tqdm: if use_tqdm:
if isinstance(output, RequestOutput): if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput # Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids) total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"] in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum( total_out_toks += sum(
...@@ -735,9 +731,6 @@ class LLM: ...@@ -735,9 +731,6 @@ class LLM:
f"output: {out_spd:.2f} toks/s") f"output: {out_spd:.2f} toks/s")
pbar.update(1) pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID. # Sort the outputs by request ID.
......
...@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict ...@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
...@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -195,7 +195,6 @@ async def main(args): ...@@ -195,7 +195,6 @@ async def main(args):
engine = AsyncLLMEngine.from_engine_args( engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config() model_config = await engine.get_model_config()
if args.disable_log_requests: if args.disable_log_requests:
......
...@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
return self.response_role return self.response_role
else: return request.messages[-1]["role"]
return request.messages[-1]["role"]
async def chat_completion_stream_generator( async def chat_completion_stream_generator(
self, self,
...@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing): ...@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
# Send response for each token for each request.n (index) # Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser( tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None tokenizer) if self.tool_parser else None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))
all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
try: try:
async for res in result_generator: async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
# We need to do it here, because if there are exceptions in # We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST # the result_generator, it needs to be sent as the FIRST
# response (by the try...catch). # response (by the try...catch).
...@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
# if continuous usage stats are requested, add it # if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats: if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(
usage = UsageInfo(prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=0, completion_tokens=0,
total_tokens=prompt_tokens) total_tokens=num_prompt_tokens)
chunk.usage = usage chunk.usage = usage
# otherwise don't # otherwise don't
else: else:
...@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request.stream_options.include_usage): request.stream_options.include_usage):
if (request.stream_options. if (request.stream_options.
continuous_usage_stats): continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=0, completion_tokens=0,
total_tokens=prompt_tokens) total_tokens=num_prompt_tokens)
chunk.usage = usage chunk.usage = usage
else: else:
chunk.usage = None chunk.usage = None
...@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing): ...@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False first_iteration = False
for output in res.outputs: for output in res.outputs:
i = output.index i = output.index
if finish_reason_sent[i]: if finish_reason_sent[i]:
continue continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs and request.top_logprobs is not None: if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, ( assert output.logprobs is not None, (
"Did not output logprobs") "Did not output logprobs")
logprobs = self._create_chat_logprobs( logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids, token_ids=output.token_ids,
top_logprobs=out_logprobs, top_logprobs=output.logprobs,
tokenizer=tokenizer, tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs, num_output_top_logprobs=request.top_logprobs,
) )
else: else:
logprobs = None logprobs = None
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text
delta_message: Optional[DeltaMessage] = None delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is if tool_choice_function_name:
ChatCompletionNamedToolChoiceParam):
delta_message = DeltaMessage(tool_calls=[ delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall( DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name, name=tool_choice_function_name,
arguments=delta_text), arguments=delta_text),
index=i) index=i)
]) ])
# handle streaming deltas for tools with "auto" tool choice # handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request) elif tool_choice_auto:
and tool_parser): assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = ( delta_message = (
tool_parser.extract_tool_calls_streaming( tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i], previous_text=previous_text,
current_text=output.text, current_text=current_text,
delta_text=delta_text, delta_text=delta_text,
previous_token_ids= \ previous_token_ids=previous_token_ids,
output.token_ids[ current_token_ids=current_token_ids,
:-1 * len(delta_token_ids) delta_token_ids=output.token_ids))
],
current_token_ids=output.token_ids, # update the previous values for the next iteration
delta_token_ids=delta_token_ids previous_texts[i] = current_text
) all_previous_token_ids[i] = current_token_ids
)
# handle streaming just a content delta # handle streaming just a content delta
else: else:
delta_message = DeltaMessage(content=delta_text) delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration # set the previous values for the next iteration
previous_texts[i] = output.text previous_num_tokens[i] += len(output.token_ids)
previous_num_tokens[i] = len(output.token_ids)
# if the message delta is None (e.g. because it was a # if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise # "control token" for tool calls or the parser otherwise
...@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous # handle usage stats if requested & if continuous
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats): if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + total_tokens=num_prompt_tokens +
completion_tokens, completion_tokens,
) )
chunk.usage = usage chunk.usage = usage
...@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser.prev_tool_call_arr[index].get( tool_parser.prev_tool_call_arr[index].get(
"arguments", {})) "arguments", {}))
# get what we've streamed so for for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool
actual_call = tool_parser.streamed_args_for_tool[ actual_call = tool_parser.streamed_args_for_tool[
index] index]
...@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
]) ])
# Send the finish response for each request.n only once # Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice( choice_data = ChatCompletionResponseStreamChoice(
index=i, index=i,
delta=delta_message, delta=delta_message,
...@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model=model_name) model=model_name)
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats): if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids) completion_tokens = len(output.token_ids)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
total_tokens=prompt_tokens + total_tokens=num_prompt_tokens +
completion_tokens, completion_tokens,
) )
chunk.usage = usage chunk.usage = usage
...@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage # is sent, send the usage
if (request.stream_options if (request.stream_options
and request.stream_options.include_usage): and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo( final_usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=num_prompt_tokens,
completion_tokens=previous_num_tokens[i], completion_tokens=completion_tokens,
total_tokens=prompt_tokens + previous_num_tokens[i], total_tokens=num_prompt_tokens + completion_tokens,
) )
final_usage_chunk = ChatCompletionStreamResponse( final_usage_chunk = ChatCompletionStreamResponse(
...@@ -607,7 +625,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -607,7 +625,7 @@ class OpenAIServingChat(OpenAIServing):
# if auto tools are not enabled, and a named tool choice using # if auto tools are not enabled, and a named tool choice using
# outlines is not being used # outlines is not being used
if not (self.enable_auto_tools if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance( or not self.tool_parser) and not isinstance(
request.tool_choice, request.tool_choice,
ChatCompletionNamedToolChoiceParam): ChatCompletionNamedToolChoiceParam):
...@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or "") or "")
choice.message.content = full_message choice.message.content = full_message
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids) num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum( num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs) len(output.token_ids) for output in final_res.outputs)
...@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return bool( return bool(
# if there is a delta message that includes tool calls which # if there is a delta message that includes tool calls which
# include a function that has arguments # include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0] and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
) )
...@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
...@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = res.prompt_logprobs prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
delta_token_ids: GenericSequence[int] delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[ out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]] int, Logprob]]]]
...@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
# only return the prompt # only return the prompt
delta_text = prompt_text delta_text = prompt_text
...@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
assert prompt_logprobs is not None assert prompt_logprobs is not None
# echo the prompt and first token # echo the prompt and first token
...@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text
delta_token_ids = output.token_ids[ delta_token_ids = output.token_ids
previous_num_tokens[i]:] out_logprobs = output.logprobs
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
assert out_logprobs is not None, ( assert out_logprobs is not None, (
...@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs=out_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer, tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]), initial_text_offset=previous_text_lens[i],
) )
else: else:
logprobs = None logprobs = None
previous_texts[i] = output.text previous_text_lens[i] += len(output.text)
previous_num_tokens[i] = len(output.token_ids) previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason finish_reason = output.finish_reason
stop_reason = output.stop_reason stop_reason = output.stop_reason
...@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None): or output.finish_reason is not None):
prompt_tokens = len(prompt_token_ids) prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = len(output.token_ids) completion_tokens = previous_num_tokens[i]
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
...@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for final_res in final_res_batch: for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt
...@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
) )
choices.append(choice_data) choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids) num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
......
...@@ -33,7 +33,6 @@ class Hermes2ProToolParser(ToolParser): ...@@ -33,7 +33,6 @@ class Hermes2ProToolParser(ToolParser):
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.current_tool_name_sent = False
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
......
...@@ -61,7 +61,6 @@ if TYPE_CHECKING: ...@@ -61,7 +61,6 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
...@@ -409,14 +408,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -409,14 +408,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_RPC_GET_DATA_TIMEOUT_MS": "VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045
"VLLM_ALLOW_ENGINE_USE_RAY":
lambda:
(os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in
("1", "true")),
# a list of plugin names to load, separated by commas. # a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded # if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded # if this is set to an empty string, no plugins will be loaded
......
...@@ -5,7 +5,8 @@ from typing_extensions import TypeIs ...@@ -5,7 +5,8 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs) LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -60,8 +61,38 @@ def parse_and_batch_prompt( ...@@ -60,8 +61,38 @@ def parse_and_batch_prompt(
for elem in prompt for elem in prompt
] ]
raise ValueError("prompt must be a string, array of strings, " raise TypeError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays") "array of tokens, or array of token arrays")
class ParsedStrPrompt(TypedDict):
    """Tagged result for a singleton prompt supplied as a plain string."""
    # Discriminator: always the literal "str" for this variant.
    type: Literal["str"]
    # The raw prompt string exactly as it was passed in.
    content: str
class ParsedTextPrompt(TypedDict):
    """Tagged result for a singleton prompt supplied as a ``TextPrompt``."""
    # Discriminator: always the literal "text" for this variant.
    type: Literal["text"]
    # The original TextPrompt dict (carries at least a "prompt" key).
    content: TextPrompt
class ParsedTokensPrompt(TypedDict):
    """Tagged result for a singleton prompt supplied as a ``TokensPrompt``."""
    # Discriminator: always the literal "tokens" for this variant.
    type: Literal["tokens"]
    # The original TokensPrompt dict (carries a "prompt_token_ids" key).
    content: TokensPrompt
def parse_singleton_prompt(
    inputs: SingletonPromptInputs,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
    """Classify a singleton prompt and wrap it in a tagged dict.

    Arguments:

    * inputs: a plain string, or a dict holding either a
      ``"prompt_token_ids"`` key or a ``"prompt"`` key

    Returns:

    * A dict whose ``"type"`` is ``"str"``, ``"tokens"`` or ``"text"`` and
      whose ``"content"`` is the untouched input.

    Raises:

    * TypeError: the input is neither a string nor a dict with one of the
      recognised keys.
    """
    if isinstance(inputs, str):
        return ParsedStrPrompt(type="str", content=inputs)

    if isinstance(inputs, dict):
        # Token ids win when a dict happens to carry both keys.
        if "prompt_token_ids" in inputs:
            return ParsedTokensPrompt(type="tokens",
                                      content=inputs)  # type: ignore
        if "prompt" in inputs:
            return ParsedTextPrompt(type="text", content=inputs)

    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt( def is_explicit_encoder_decoder_prompt(
......
import asyncio
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict
# Module-level logger, keyed by this module's name.
logger = init_logger(__name__)

# Triple of (prompt text, prompt token ids, multi-modal data) produced when a
# singleton prompt is broken into its parts. The text and multi-modal data may
# be None; the token ids are always present.
PromptComponents = Tuple[Optional[str], List[int],
                         Optional["MultiModalDataDict"]]

# Same triple for the decoder side of an encoder/decoder prompt, except that
# the token ids may also be None (used when no decoder prompt was supplied).
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
                                Optional["MultiModalDataDict"]]
class InputPreprocessor:
    """Convert user-facing prompts into token-level LLM inputs.

    Holds the model config and an optional tokenizer group, and exposes
    :meth:`preprocess` / :meth:`preprocess_async` as entry points. Both
    dispatch to the encoder/decoder or decoder-only path depending on
    :meth:`is_encoder_decoder_model`.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: Optional[BaseTokenizerGroup],
    ) -> None:
        super().__init__()

        self.tokenizer = tokenizer
        self.model_config = model_config

    def get_tokenizer_group(self) -> BaseTokenizerGroup:
        """Return the tokenizer group, raising ValueError if there is none."""
        tokenizer_group = self.tokenizer
        if tokenizer_group is not None:
            return tokenizer_group

        raise ValueError("You cannot pass text prompts when "
                         "`skip_tokenizer_init` is True")

    def get_bos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """Return the BOS token id, or None (with a warning) when there is
        no tokenizer."""
        if self.tokenizer is not None:
            lora_tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            return lora_tokenizer.bos_token_id

        logger.warning("Using None for BOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_eos_token_id(self,
                         lora_request: Optional[LoRARequest] = None
                         ) -> Optional[int]:
        """Return the EOS token id, or None (with a warning) when there is
        no tokenizer."""
        if self.tokenizer is not None:
            lora_tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
            return lora_tokenizer.eos_token_id

        logger.warning("Using None for EOS token id because tokenizer "
                       "is not initialized")
        return None

    def get_decoder_start_token_id(self) -> Optional[int]:
        """
        Look up the decoder start token id of an encoder/decoder model.

        Returns None (after logging a warning) for non-encoder/decoder
        models or when the model config / HF config is unavailable. If the
        HF config carries no ``decoder_start_token_id``, the BOS token id
        is returned instead.
        """
        if not self.is_encoder_decoder_model():
            logger.warning("Using None for decoder start token id because "
                           "this is not an encoder/decoder model.")
            return None

        model_config = self.model_config
        if model_config is None or model_config.hf_config is None:
            logger.warning("Using None for decoder start token id because "
                           "model config is not available.")
            return None

        start_token_id = getattr(model_config.hf_config,
                                 'decoder_start_token_id', None)
        if start_token_id is not None:
            return start_token_id

        logger.warning("Falling back on <BOS> for decoder start token id "
                       "because decoder start token id is not available.")
        return self.get_bos_token_id()

    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
        """
        Build the decoder prompt used when an encoder/decoder request
        supplies only an encoder prompt.

        The default mirrors HuggingFace's GenerationMixin handling of a
        None decoder prompt (forcing the first decoded token to be <BOS>)
        by making the default decoder prompt a single <BOS> token. This
        lives in its own helper so that model-specific defaults can be
        added later without touching the callers.

        Returns:

        * prompt_token_ids
        """
        default_token_id = self.get_bos_token_id()
        assert default_token_id is not None
        return [default_token_id]

    def _prepare_decoder_input_ids_for_generation(
        self,
        decoder_input_ids: Optional[List[int]],
    ) -> List[int]:
        """
        Ready `decoder_input_ids` for generation with encoder-decoder models.

        Based on
        https://github.com/huggingface/transformers/blob/
        4037a2b5b1278736e566aec12e169100275545ea/
        src/transformers/generation/utils.py
        specifically GenerationMixin._prepare_decoder_input_ids_for_generation()

        Arguments:

        * decoder_input_ids: input token ids to preprocess, or None to use
          the default decoder prompt

        Returns:

        * Token list guaranteed to begin with the decoder start token
        """
        start_token_id = self.get_decoder_start_token_id()
        assert start_token_id is not None

        # No decoder prompt supplied -> fall back to the default one.
        token_ids = (self._get_default_enc_dec_decoder_prompt()
                     if decoder_input_ids is None else decoder_input_ids)

        needs_start_token = (len(token_ids) == 0
                             or token_ids[0] != start_token_id)
        if not needs_start_token:
            return token_ids

        return [start_token_id] + token_ids

    def _apply_prompt_adapter(
        self,
        prompt_token_ids: List[int],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> List[int]:
        """Prepend one ``0`` placeholder per virtual token of the prompt
        adapter; return the ids unchanged when no adapter is requested."""
        if not prompt_adapter_request:
            return prompt_token_ids

        num_virtual_tokens = (
            prompt_adapter_request.prompt_adapter_num_virtual_tokens)
        return [0] * num_virtual_tokens + prompt_token_ids

    def _tokenize_prompt(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """
        Encode a text prompt with the tokenizer group and return the
        resulting token IDs.
        """
        return self.get_tokenizer_group().encode(request_id=request_id,
                                                 prompt=prompt,
                                                 lora_request=lora_request)

    async def _tokenize_prompt_async(
        self,
        prompt: str,
        request_id: str,
        lora_request: Optional[LoRARequest],
    ) -> List[int]:
        """Async version of :meth:`_tokenize_prompt`."""
        return await self.get_tokenizer_group().encode_async(
            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request)

    def _extract_prompt_components(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """
        Split a single encoder or decoder input prompt into its parts.

        String and text prompts are tokenized; token prompts are passed
        through with no prompt text.

        Arguments:

        * request_id
        * inputs: single encoder or decoder input prompt
        * lora_request: this is only valid for decoder prompts

        Returns:

        * prompt
        * prompt_token_ids
        * multi_modal_data
        """
        parsed = parse_singleton_prompt(inputs)

        if parsed["type"] == "str":
            text = parsed["content"]
            token_ids = self._tokenize_prompt(
                text,
                request_id=request_id,
                lora_request=lora_request,
            )
            return text, token_ids, None

        if parsed["type"] == "tokens":
            tokens_prompt = parsed["content"]
            return (None, tokens_prompt["prompt_token_ids"],
                    tokens_prompt.get("multi_modal_data"))

        if parsed["type"] == "text":
            text_prompt = parsed["content"]
            text = text_prompt["prompt"]
            token_ids = self._tokenize_prompt(
                text,
                request_id=request_id,
                lora_request=lora_request,
            )
            return text, token_ids, text_prompt.get("multi_modal_data")

        assert_never(parsed)

    async def _extract_prompt_components_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
    ) -> PromptComponents:
        """Async version of :meth:`_extract_prompt_components`."""
        parsed = parse_singleton_prompt(inputs)

        if parsed["type"] == "str":
            text = parsed["content"]
            token_ids = await self._tokenize_prompt_async(
                text,
                request_id=request_id,
                lora_request=lora_request,
            )
            return text, token_ids, None

        if parsed["type"] == "tokens":
            tokens_prompt = parsed["content"]
            return (None, tokens_prompt["prompt_token_ids"],
                    tokens_prompt.get("multi_modal_data"))

        if parsed["type"] == "text":
            text_prompt = parsed["content"]
            text = text_prompt["prompt"]
            token_ids = await self._tokenize_prompt_async(
                text,
                request_id=request_id,
                lora_request=lora_request,
            )
            return text, token_ids, text_prompt.get("multi_modal_data")

        assert_never(parsed)

    def _build_enc_dec_llm_inputs(
        self,
        encoder_comps: PromptComponents,
        decoder_comps: DecoderPromptComponents,
    ) -> EncoderDecoderLLMInputs:
        """Assemble :class:`EncoderDecoderLLMInputs` from encoder and decoder
        components; multi-modal data on either side raises ValueError."""
        encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
        decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps

        has_mm_data = any(mm_data is not None
                          for mm_data in (encoder_mm_data, decoder_mm_data))
        if has_mm_data:
            raise ValueError("Multi-modal encoder-decoder models are "
                             "not supported yet")

        prepared_decoder_ids = (
            self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))

        return EncoderDecoderLLMInputs(
            prompt_token_ids=prepared_decoder_ids,
            prompt=decoder_prompt,
            encoder_prompt_token_ids=encoder_prompt_ids,
            encoder_prompt=encoder_prompt,
        )

    def _process_encoder_decoder_prompt(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """
        For encoder/decoder models only:
        turn an input prompt into an
        :class:`EncoderDecoderLLMInputs` instance.

        Two kinds of input prompt are accepted: singleton prompts, which
        carry only the encoder prompt, and explicit encoder/decoder
        prompts, which carry both.

        * Singleton encoder prompt: extract the encoder prompt token ids
          and infer default decoder prompt token ids
        * Explicit encoder/decoder prompt: extract encoder and decoder
          prompt token ids (the decoder prompt may be None, in which case
          the default is inferred as well)

        Each sub-prompt of an explicit prompt may itself be any singleton
        type, so extraction is delegated to
        :meth:`_extract_prompt_components`.

        Arguments:

        * inputs: an input prompt
        * request_id

        Returns:

        * :class:`EncoderDecoderLLMInputs` instance
        """
        encoder_comps: PromptComponents
        # Stays all-None unless an explicit decoder prompt is present.
        decoder_comps: DecoderPromptComponents = (None, None, None)

        if not is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = self._extract_prompt_components(
                inputs,
                request_id=request_id,
            )
        else:
            encoder_comps = self._extract_prompt_components(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            decoder_input = inputs["decoder_prompt"]
            if decoder_input is not None:
                decoder_comps = self._extract_prompt_components(
                    decoder_input,
                    request_id=request_id,
                )

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    async def _process_encoder_decoder_prompt_async(
        self,
        inputs: PromptInputs,
        request_id: str,
    ) -> EncoderDecoderLLMInputs:
        """Async version of :meth:`_process_encoder_decoder_prompt`."""
        encoder_comps: PromptComponents
        # Stays all-None unless an explicit decoder prompt is present.
        decoder_comps: DecoderPromptComponents = (None, None, None)

        if not is_explicit_encoder_decoder_prompt(inputs):
            encoder_comps = await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
            )
        else:
            encoder_task = self._extract_prompt_components_async(
                inputs["encoder_prompt"],
                request_id=request_id,
            )

            decoder_input = inputs["decoder_prompt"]
            if decoder_input is None:
                encoder_comps = await encoder_task
            else:
                decoder_task = self._extract_prompt_components_async(
                    decoder_input,
                    request_id=request_id,
                )
                # Run both extractions concurrently.
                encoder_comps, decoder_comps = await asyncio.gather(
                    encoder_task, decoder_task)

        return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)

    def _build_decoder_only_llm_inputs(
        self,
        prompt_comps: PromptComponents,
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> LLMInputs:
        """Assemble :class:`LLMInputs` from prompt components, applying the
        prompt adapter (if any) to the token ids first."""
        prompt, prompt_token_ids, multi_modal_data = prompt_comps

        adapted_token_ids = self._apply_prompt_adapter(
            prompt_token_ids, prompt_adapter_request=prompt_adapter_request)

        return LLMInputs(prompt_token_ids=adapted_token_ids,
                         prompt=prompt,
                         multi_modal_data=multi_modal_data)

    def _process_decoder_only_prompt(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """
        For decoder-only models:
        turn an input prompt into an :class:`LLMInputs` instance.

        Arguments:

        * inputs: input prompt
        * request_id
        * lora_request
        * prompt_adapter_request

        Returns:

        * :class:`LLMInputs` instance
        """
        return self._build_decoder_only_llm_inputs(
            self._extract_prompt_components(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            ),
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _process_decoder_only_prompt_async(
        self,
        inputs: SingletonPromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> LLMInputs:
        """Async version of :meth:`_process_decoder_only_prompt`."""
        return self._build_decoder_only_llm_inputs(
            await self._extract_prompt_components_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
            ),
            prompt_adapter_request=prompt_adapter_request,
        )

    def preprocess(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Preprocess the input prompt.

        Raises ValueError when an explicit encoder/decoder prompt is given
        to a decoder-only model.
        """
        if not self.is_encoder_decoder_model():
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            return self._process_decoder_only_prompt(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        # Encoder-decoder models need the prompt mapped onto separate
        # encoder and decoder inputs.
        return self._process_encoder_decoder_prompt(
            inputs,
            request_id=request_id,
        )

    async def preprocess_async(
        self,
        inputs: PromptInputs,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
        """Async version of :meth:`preprocess`."""
        if not self.is_encoder_decoder_model():
            if is_explicit_encoder_decoder_prompt(inputs):
                raise ValueError("Cannot pass encoder-decoder prompt "
                                 "to decoder-only models")

            # Decoder-only operation
            return await self._process_decoder_only_prompt_async(
                inputs,
                request_id=request_id,
                lora_request=lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

        # Encoder-decoder models need the prompt mapped onto separate
        # encoder and decoder inputs.
        return await self._process_encoder_decoder_prompt_async(
            inputs,
            request_id=request_id,
        )

    def is_encoder_decoder_model(self):
        """Return the model config's ``is_encoder_decoder_model`` value."""
        return self.model_config.is_encoder_decoder_model
...@@ -410,6 +410,7 @@ def fused_topk( ...@@ -410,6 +410,7 @@ def fused_topk(
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -443,7 +444,8 @@ def grouped_topk(hidden_states: torch.Tensor,
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
return topk_weights, topk_ids.to(torch.int32)
def get_config_dtype_str(dtype: torch.dtype, def get_config_dtype_str(dtype: torch.dtype,
......
...@@ -990,7 +990,7 @@ def get_rope( ...@@ -990,7 +990,7 @@ def get_rope(
base, is_neox_style, dtype, short_factor, long_factor, base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "mrope": elif scaling_type == "mrope":
return MRotaryEmbedding( rotary_emb = MRotaryEmbedding(
head_size, head_size,
rotary_dim, rotary_dim,
max_position, max_position,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment