Commit ad58e9b3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1.post2' into v0.6.1.post2-dev

parents 408f663a 9ba0817f
...@@ -10,6 +10,7 @@ from pathlib import Path ...@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import openai import openai
import pytest
import requests import requests
from openai.types.completion import Completion from openai.types.completion import Completion
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -22,7 +23,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -22,7 +23,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
get_open_port, is_hip)
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage, from amdsmi import (amdsmi_get_gpu_vram_usage,
...@@ -356,12 +358,23 @@ def error_on_warning(): ...@@ -356,12 +358,23 @@ def error_on_warning():
yield yield
def get_physical_device_indices(devices):
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if visible_devices is None:
return devices
visible_indices = [int(x) for x in visible_devices.split(",")]
index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
return [index_mapping[i] for i in devices if i in index_mapping]
@_nvml() @_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int], def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int, threshold_bytes: int,
timeout_s: float = 120) -> None: timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda # Use nvml instead of pytorch to reduce measurement error from torch cuda
# context. # context.
devices = get_physical_device_indices(devices)
start_time = time.time() start_time = time.time()
while True: while True:
output: Dict[int, str] = {} output: Dict[int, str] = {}
...@@ -441,6 +454,22 @@ def fork_new_process_for_each_test( ...@@ -441,6 +454,22 @@ def fork_new_process_for_each_test(
return wrapper return wrapper
def multi_gpu_test(*, num_gpus: int):
"""
Decorate a test to be run only when multiple GPUs are available.
"""
test_selector = getattr(pytest.mark, f"distributed_{num_gpus}_gpus")
test_skipif = pytest.mark.skipif(
cuda_device_count_stateless() < num_gpus,
reason=f"Need at least {num_gpus} GPUs to run the test.",
)
def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
return test_selector(test_skipif(fork_new_process_for_each_test(f)))
return wrapper
async def completions_with_server_args( async def completions_with_server_args(
prompts: List[str], prompts: List[str],
model_name: str, model_name: str,
......
...@@ -251,16 +251,36 @@ def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, ...@@ -251,16 +251,36 @@ def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def advance_step(num_seqs: int, num_queries: int, block_size: int, def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, input_tokens: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor, sampled_token_ids: torch.Tensor,
slot_mapping: torch.Tensor, input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None: block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner""" """Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step(num_seqs, num_queries, block_size, return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
input_tokens, sampled_token_ids, block_size, input_tokens,
input_positions, seq_lens, slot_mapping, sampled_token_ids,
block_tables) input_positions, seq_lens,
slot_mapping, block_tables)
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
block_table_bound: torch.Tensor) -> None:
return torch.ops._C.advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables,
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
block_table_bound)
# trans_w16 # trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
......
...@@ -83,7 +83,9 @@ class AttentionBackend(ABC): ...@@ -83,7 +83,9 @@ class AttentionBackend(ABC):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
def advance_step(self, num_seqs: int, num_queries: int): def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -122,6 +122,40 @@ def _( ...@@ -122,6 +122,40 @@ def _(
return torch.empty_like(decode_query) return torch.empty_like(decode_query)
@torch.library.custom_op("vllm::reshape_and_cache_flash",
mutates_args=["kv_cache"])
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
"""Inductor cannot deal with inplace operations on views.
See https://github.com/pytorch/pytorch/issues/131192
and https://github.com/pytorch/pytorch/issues/130174
This is a workaround to hide the view operation from the inductor.
"""
return torch.ops._C_cache_ops.reshape_and_cache_flash(
key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype,
k_scale, v_scale)
@reshape_and_cache_flash.register_fake # type: ignore
def _(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
) -> None:
pass
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -346,7 +380,7 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -346,7 +380,7 @@ class FlashAttentionMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step(num_seqs=num_seqs, ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries, num_queries=num_queries,
block_size=block_size, block_size=block_size,
input_tokens=model_input.input_tokens, input_tokens=model_input.input_tokens,
...@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory profiling run.
ops.reshape_and_cache_flash( torch.ops.vllm.reshape_and_cache_flash(
key, key,
value, value,
key_cache, kv_cache,
value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale, k_scale,
...@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl):
assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
...@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl):
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None if (kv_cache is None or prefill_meta.block_tables is None
...@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
out = torch.ops.vllm.flash_attn_varlen_func( prefill_output = torch.ops.vllm.flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -701,14 +736,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -701,14 +736,11 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
output[: prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -725,8 +757,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -725,8 +757,7 @@ class FlashAttentionImpl(AttentionImpl):
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output[ decode_output = torch.ops.vllm.flash_attn_with_kvcache(
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1), decode_query.unsqueeze(1),
key_cache, key_cache,
value_cache, value_cache,
...@@ -738,5 +769,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -738,5 +769,11 @@ class FlashAttentionImpl(AttentionImpl):
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
).squeeze(1) ).squeeze(1)
# Reshape the output tensor. if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, ...@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc: Optional[torch.Tensor] = None query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None
# used for GPU in-place advance_step
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr: # An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8] # request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7] # request 2, page indices [1, 6, 7]
...@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
assert self.paged_kv_indices is not None assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
batch_size = self.query_start_loc.shape[0] - 1 batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0 assert batch_size >= 0
# We will use flash attention for profiling to # We will use flash attention for profiling to
...@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward() self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward( self.prefill_wrapper.begin_forward(
...@@ -335,7 +344,6 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -335,7 +344,6 @@ class FlashInferMetadata(AttentionMetadata):
self.num_qo_heads, self.num_kv_heads, self.head_dim, self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size) self.page_size)
else: else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None assert self.paged_kv_last_page_len is not None
...@@ -343,6 +351,11 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -343,6 +351,11 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
# handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None assert self.decode_wrapper is not None
self.decode_wrapper.end_forward() self.decode_wrapper.end_forward()
...@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
return self return self
def advance_step(
self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
):
"""
Update metadata in-place to advance one decode step.
"""
assert num_seqs > 0
assert num_queries > 0
assert model_input.attn_metadata is not None
assert sampled_token_ids is not None
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
model_input.input_tokens[:num_queries] = sampled_token_ids.flatten()
# Update GPU tensors
ops.advance_step_flashinfer(
num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=model_input.input_tokens,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables,
paged_kv_indices=self.paged_kv_indices,
paged_kv_indptr=self.paged_kv_indptr,
paged_kv_last_page_len=self.paged_kv_last_page_len,
block_table_bound=self.block_table_bound)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_indptr: List[int] = [0] self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request # paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = [] self.paged_kv_last_page_len: List[int] = []
self.total_blocks = 0
self.is_profile_run: bool = False self.is_profile_run: bool = False
def _add_seq_group( def _add_seq_group(
...@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# block_table_bound is 1 with 1 valid block. # block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16, # If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block. # block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \ block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \ if seq_len % self.block_size != 0 \
else seq_len // self.block_size else seq_len // self.block_size
...@@ -541,9 +597,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -541,9 +597,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to( block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True) device, non_blocking=True)
...@@ -583,6 +649,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -583,6 +649,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
out=query_start_loc[1:]) out=query_start_loc[1:])
if len(self.paged_kv_indptr) > 0: if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu", device="cpu",
dtype=torch.int) dtype=torch.int)
...@@ -591,10 +661,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -591,10 +661,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int) dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor( paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int) self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else: else:
paged_kv_indices_tensor = None paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None
if self.runner.kv_cache_dtype.startswith("fp8"): if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
...@@ -613,6 +688,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -613,6 +688,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor, paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor,
block_table_bound=block_table_bound_tensor,
seq_lens_tensor=seq_lens_tensor,
num_qo_heads=self.runner.model_config.get_num_attention_heads( num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config), self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads( num_kv_heads=self.runner.model_config.get_num_kv_heads(
......
...@@ -869,6 +869,13 @@ class ParallelConfig: ...@@ -869,6 +869,13 @@ class ParallelConfig:
f"distributed executor backend " f"distributed executor backend "
f"'{self.distributed_executor_backend}'.") f"'{self.distributed_executor_backend}'.")
if current_platform.is_tpu() and self.world_size > 1:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
if self.distributed_executor_backend != "ray":
raise ValueError(
"TPU backend only supports Ray for distributed inference.")
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
...@@ -876,7 +883,7 @@ class ParallelConfig: ...@@ -876,7 +883,7 @@ class ParallelConfig:
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend = "mp" backend = "mp"
ray_found = ray_utils.ray_is_available() ray_found = ray_utils.ray_is_available()
if (torch.cuda.is_available() if (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size): and cuda_device_count_stateless() < self.world_size):
if not ray_found: if not ray_found:
raise ValueError("Unable to load Ray which is " raise ValueError("Unable to load Ray which is "
......
...@@ -843,6 +843,13 @@ class EngineArgs: ...@@ -843,6 +843,13 @@ class EngineArgs:
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config() model_config = self.create_model_config()
if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len self.max_model_len, # neuron needs block_size = max_model_len
...@@ -874,7 +881,10 @@ class EngineArgs: ...@@ -874,7 +881,10 @@ class EngineArgs:
# If not explicitly set, enable chunked prefill by default for # If not explicitly set, enable chunked prefill by default for
# long context (> 32K) models. This is to avoid OOM errors in the # long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase. # initial memory profiling phase.
if use_long_context:
# Chunked prefill is currently disabled for multimodal models by
# default.
if use_long_context and not model_config.is_multimodal_model:
is_gpu = device_config.device_type == "cuda" is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window() use_sliding_window = (model_config.get_sliding_window()
is not None) is not None)
...@@ -1035,7 +1045,6 @@ class EngineArgs: ...@@ -1035,7 +1045,6 @@ class EngineArgs:
@dataclass @dataclass
class AsyncEngineArgs(EngineArgs): class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine.""" """Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False disable_log_requests: bool = False
@staticmethod @staticmethod
...@@ -1043,16 +1052,6 @@ class AsyncEngineArgs(EngineArgs): ...@@ -1043,16 +1052,6 @@ class AsyncEngineArgs(EngineArgs):
async_args_only: bool = False) -> FlexibleArgumentParser: async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only: if not async_args_only:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
help='Use Ray to start the LLM engine in a '
'separate process as the server process.'
'(DEPRECATED. This argument is deprecated '
'and will be removed in a future update. '
'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
'use it. See '
'https://github.com/vllm-project/vllm/issues/7045.'
')')
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
......
...@@ -4,22 +4,18 @@ from functools import partial ...@@ -4,22 +4,18 @@ from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union) Mapping, Optional, Set, Tuple, Type, Union)
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig) ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine, from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
PromptComponents, SchedulerOutputState)
from vllm.engine.metrics_types import StatLoggerBase from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, from vllm.executor.ray_utils import initialize_ray_cluster
SingletonPromptInputs) from vllm.inputs import PromptInputs
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -30,7 +26,6 @@ from vllm.sampling_params import SamplingParams ...@@ -30,7 +26,6 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
...@@ -404,139 +399,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -404,139 +399,6 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
async def _tokenize_prompt_async(
self,
prompt: str,
request_id: str,
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")
return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
prompt = None
prompt_token_ids = inputs["prompt_token_ids"]
else:
# NOTE: This extra assignment is required to pass mypy
prompt = parsed_prompt = inputs["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
parsed_prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = inputs.get("multi_modal_data")
else:
assert_never(inputs)
return prompt, prompt_token_ids, multi_modal_data
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
decoder_task = self._extract_prompt_components_async(
decoder_input,
request_id=request_id,
)
encoder_comps, decoder_comps = await asyncio.gather(
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
)
decoder_comps = None, None, None
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
request_id=request_id,
lora_request=lora_request,
)
return self._build_decoder_only_llm_inputs(
prompt_comps,
prompt_adapter_request=prompt_adapter_request,
)
async def process_model_inputs_async(
self,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
"""Async version of :meth:`process_model_inputs`."""
if self.is_encoder_decoder_model():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs = await self._process_encoder_decoder_prompt_async(
inputs,
request_id=request_id,
)
else:
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
model_inputs = await self._process_decoder_only_prompt_async(
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
return self.input_processor(model_inputs)
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
...@@ -554,12 +416,13 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -554,12 +416,13 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
processed_inputs = await self.process_model_inputs_async( preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs, inputs,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -590,9 +453,6 @@ class AsyncLLMEngine: ...@@ -590,9 +453,6 @@ class AsyncLLMEngine:
worker_use_ray: Whether to use Ray for model workers. Required for worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as distributed execution. Should be the same as
`parallel_config.worker_use_ray`. `parallel_config.worker_use_ray`.
engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
...@@ -604,41 +464,23 @@ class AsyncLLMEngine: ...@@ -604,41 +464,23 @@ class AsyncLLMEngine:
def __init__(self, def __init__(self,
worker_use_ray: bool, worker_use_ray: bool,
engine_use_ray: bool,
*args, *args,
log_requests: bool = True, log_requests: bool = True,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs) self.engine = self._engine_class(*args, **kwargs)
# This ensures quick processing of request outputs # This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed, # so the append to asyncio queues is not delayed,
# especially for multi-step. # especially for multi-step.
# #
# TODO: Currently, disabled for engine_use_ray, ask self.use_process_request_outputs_callback = True
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback: if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \ self.engine.process_request_outputs_callback = \
self.process_request_outputs self.process_request_outputs
if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
"be removed in a future update. "
"See https://github.com/vllm-project/vllm/issues/7045.")
if envs.VLLM_ALLOW_ENGINE_USE_RAY:
print_warning_once(
"VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray")
else:
raise ValueError("`--engine-use-ray` is deprecated. "
"Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
"force use it")
self.background_loop: Optional[asyncio.Future] = None self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded # We need to keep a reference to unshielded
# task as well to prevent it from being garbage # task as well to prevent it from being garbage
...@@ -725,16 +567,11 @@ class AsyncLLMEngine: ...@@ -725,16 +567,11 @@ class AsyncLLMEngine:
# Create the engine configs. # Create the engine configs.
engine_config = engine_args.create_engine_config() engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
executor_class = cls._get_executor_cls(engine_config) executor_class = cls._get_executor_cls(engine_config)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
executor_class.uses_ray, executor_class.uses_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(), **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
...@@ -777,10 +614,6 @@ class AsyncLLMEngine: ...@@ -777,10 +614,6 @@ class AsyncLLMEngine:
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer: ) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)
return await (self.engine.get_tokenizer_group(). return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request)) get_lora_tokenizer_async(lora_request))
...@@ -814,26 +647,6 @@ class AsyncLLMEngine: ...@@ -814,26 +647,6 @@ class AsyncLLMEngine:
self._background_loop_unshielded = None self._background_loop_unshielded = None
self.background_loop = None self.background_loop = None
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
if not self.engine_use_ray:
engine_class = self._engine_class
elif self.worker_use_ray:
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
else:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments.
cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"]
if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization
else:
num_gpus = 1
engine_class = ray.remote(num_gpus=num_gpus)(
self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self, virtual_engine: int) -> bool: async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests. """Kick the engine to process the waiting requests.
...@@ -844,12 +657,7 @@ class AsyncLLMEngine: ...@@ -844,12 +657,7 @@ class AsyncLLMEngine:
for new_request in new_requests: for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue. # Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
try: try:
if self.engine_use_ray:
await self.engine.add_request.remote( # type: ignore
**new_request)
else:
await self.engine.add_request_async(**new_request) await self.engine.add_request_async(**new_request)
except ValueError as e: except ValueError as e:
# TODO: use a vLLM specific error for failed validation # TODO: use a vLLM specific error for failed validation
...@@ -862,9 +670,6 @@ class AsyncLLMEngine: ...@@ -862,9 +670,6 @@ class AsyncLLMEngine:
if aborted_requests: if aborted_requests:
await self._engine_abort(aborted_requests) await self._engine_abort(aborted_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async(virtual_engine) request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams. # Put the outputs into the corresponding streams.
...@@ -891,15 +696,9 @@ class AsyncLLMEngine: ...@@ -891,15 +696,9 @@ class AsyncLLMEngine:
return all_finished return all_finished
async def _engine_abort(self, request_ids: Iterable[str]): async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids) self.engine.abort_request(request_ids)
async def run_engine_loop(self): async def run_engine_loop(self):
if self.engine_use_ray:
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \ pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size has_requests_in_progress = [False] * pipeline_parallel_size
...@@ -912,11 +711,6 @@ class AsyncLLMEngine: ...@@ -912,11 +711,6 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that # timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages, # they can process any other queued control plane messages,
# such as add/remove lora adapters. # such as add/remove lora adapters.
if self.engine_use_ray:
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async() await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests() await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!") logger.debug("Got new requests!")
...@@ -938,16 +732,8 @@ class AsyncLLMEngine: ...@@ -938,16 +732,8 @@ class AsyncLLMEngine:
for task in done: for task in done:
result = task.result() result = task.result()
virtual_engine = requests_in_progress.index(task) virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = ( has_unfinished_requests = (
self.engine. self.engine.has_unfinished_requests_for_virtual_engine(
has_unfinished_requests_for_virtual_engine(
virtual_engine)) virtual_engine))
if result or has_unfinished_requests: if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = ( requests_in_progress[virtual_engine] = (
...@@ -1190,51 +976,28 @@ class AsyncLLMEngine: ...@@ -1190,51 +976,28 @@ class AsyncLLMEngine:
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config() return self.engine.get_model_config()
async def get_parallel_config(self) -> ParallelConfig: async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine.""" """Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config() return self.engine.get_parallel_config()
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine.""" """Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_decoding_config.remote( # type: ignore
)
else:
return self.engine.get_decoding_config() return self.engine.get_decoding_config()
async def get_scheduler_config(self) -> SchedulerConfig: async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine.""" """Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config() return self.engine.get_scheduler_config()
async def get_lora_config(self) -> LoRAConfig: async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine.""" """Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config() return self.engine.get_lora_config()
async def do_log_stats( async def do_log_stats(
self, self,
scheduler_outputs: Optional[SchedulerOutputs] = None, scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None: model_output: Optional[List[SamplerOutput]] = None) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote( # type: ignore
scheduler_outputs, model_output)
else:
self.engine.do_log_stats() self.engine.do_log_stats()
async def check_health(self) -> None: async def check_health(self) -> None:
...@@ -1244,40 +1007,30 @@ class AsyncLLMEngine: ...@@ -1244,40 +1007,30 @@ class AsyncLLMEngine:
if self.is_stopped: if self.is_stopped:
raise AsyncEngineDeadError("Background loop is stopped.") raise AsyncEngineDeadError("Background loop is stopped.")
if self.engine_use_ray:
try:
await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else:
await self.engine.check_health_async() await self.engine.check_health_async()
logger.debug("Health check took %fs", time.perf_counter() - t) logger.debug("Health check took %fs", time.perf_counter() - t)
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
if self.engine_use_ray:
return await self.engine.is_tracing_enabled.remote( # type: ignore
)
else:
return self.engine.is_tracing_enabled() return self.engine.is_tracing_enabled()
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if self.engine_use_ray:
ray.get(
self.engine.add_logger.remote( # type: ignore
logger_name=logger_name, logger=logger))
else:
self.engine.add_logger(logger_name=logger_name, logger=logger) self.engine.add_logger(logger_name=logger_name, logger=logger)
def remove_logger(self, logger_name: str) -> None: def remove_logger(self, logger_name: str) -> None:
if self.engine_use_ray:
ray.get(
self.engine.remove_logger.remote( # type: ignore
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name) self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None: async def start_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile") self.engine.model_executor._run_workers("start_profile")
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile") self.engine.model_executor._run_workers("stop_profile")
This diff is collapsed.
...@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions ...@@ -19,7 +19,7 @@ from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
...@@ -642,14 +642,12 @@ class LLM: ...@@ -642,14 +642,12 @@ class LLM:
raise ValueError("The lengths of prompts and lora_request " raise ValueError("The lengths of prompts and lora_request "
"must be the same.") "must be the same.")
if isinstance(params, list): for sp in params if isinstance(params, list) else (params, ):
params = [ if isinstance(sp, SamplingParams):
self._add_guided_processor(param, guided_options) self._add_guided_processor(sp, guided_options)
if isinstance(param, SamplingParams) else param
for param in params # We only care about the final output
] sp.output_kind = RequestOutputKind.FINAL_ONLY
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, request_inputs in enumerate(inputs):
...@@ -709,9 +707,6 @@ class LLM: ...@@ -709,9 +707,6 @@ class LLM:
f"output: {0:.2f} toks/s"), f"output: {0:.2f} toks/s"),
) )
# In the loop below, only finished outputs are used
self.llm_engine.step_return_finished_only = True
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0 total_in_toks = 0
...@@ -724,6 +719,7 @@ class LLM: ...@@ -724,6 +719,7 @@ class LLM:
if use_tqdm: if use_tqdm:
if isinstance(output, RequestOutput): if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput # Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids) total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"] in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum( total_out_toks += sum(
...@@ -735,9 +731,6 @@ class LLM: ...@@ -735,9 +731,6 @@ class LLM:
f"output: {out_spd:.2f} toks/s") f"output: {out_spd:.2f} toks/s")
pbar.update(1) pbar.update(1)
# Restore original behavior
self.llm_engine.step_return_finished_only = False
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
# Sort the outputs by request ID. # Sort the outputs by request ID.
......
...@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict ...@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
...@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -195,7 +195,6 @@ async def main(args): ...@@ -195,7 +195,6 @@ async def main(args):
engine = AsyncLLMEngine.from_engine_args( engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config() model_config = await engine.get_model_config()
if args.disable_log_requests: if args.disable_log_requests:
......
This diff is collapsed.
...@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
try: try:
async for prompt_idx, res in result_generator: async for prompt_idx, res in result_generator:
...@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = res.prompt_logprobs prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
delta_token_ids: GenericSequence[int] delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[ out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]] int, Logprob]]]]
...@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo and request.max_tokens == 0: if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
# only return the prompt # only return the prompt
delta_text = prompt_text delta_text = prompt_text
...@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True has_echoed[i] = True
elif (request.echo and request.max_tokens > 0 elif (request.echo and request.max_tokens > 0
and not has_echoed[i]): and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None assert prompt_text is not None
assert prompt_logprobs is not None assert prompt_logprobs is not None
# echo the prompt and first token # echo the prompt and first token
...@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True has_echoed[i] = True
else: else:
# return just the delta # return just the delta
delta_text = output.text[len(previous_texts[i]):] delta_text = output.text
delta_token_ids = output.token_ids[ delta_token_ids = output.token_ids
previous_num_tokens[i]:] out_logprobs = output.logprobs
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
if request.logprobs is not None: if request.logprobs is not None:
assert out_logprobs is not None, ( assert out_logprobs is not None, (
...@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs=out_logprobs, top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs, num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer, tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]), initial_text_offset=previous_text_lens[i],
) )
else: else:
logprobs = None logprobs = None
previous_texts[i] = output.text previous_text_lens[i] += len(output.text)
previous_num_tokens[i] = len(output.token_ids) previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason finish_reason = output.finish_reason
stop_reason = output.stop_reason stop_reason = output.stop_reason
...@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage): and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None): or output.finish_reason is not None):
prompt_tokens = len(prompt_token_ids) prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = len(output.token_ids) completion_tokens = previous_num_tokens[i]
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
...@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for final_res in final_res_batch: for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt prompt_text = final_res.prompt
...@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
) )
choices.append(choice_data) choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids) num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo( usage = UsageInfo(
prompt_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens,
......
...@@ -33,7 +33,6 @@ class Hermes2ProToolParser(ToolParser): ...@@ -33,7 +33,6 @@ class Hermes2ProToolParser(ToolParser):
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.current_tool_name_sent = False
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
......
...@@ -61,7 +61,6 @@ if TYPE_CHECKING: ...@@ -61,7 +61,6 @@ if TYPE_CHECKING:
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_TEST_FORCE_FP8_MARLIN: bool = False
VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000 VLLM_RPC_GET_DATA_TIMEOUT_MS: int = 5000
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
...@@ -409,14 +408,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -409,14 +408,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_RPC_GET_DATA_TIMEOUT_MS": "VLLM_RPC_GET_DATA_TIMEOUT_MS":
lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")), lambda: int(os.getenv("VLLM_RPC_GET_DATA_TIMEOUT_MS", "5000")),
# If set, allow running the engine as a separate ray actor,
# which is a deprecated feature soon to be removed.
# See https://github.com/vllm-project/vllm/issues/7045
"VLLM_ALLOW_ENGINE_USE_RAY":
lambda:
(os.environ.get("VLLM_ALLOW_ENGINE_USE_RAY", "0").strip().lower() in
("1", "true")),
# a list of plugin names to load, separated by commas. # a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded # if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded # if this is set to an empty string, no plugins will be loaded
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment