Commit ad58e9b3 authored by zhuwenwen
Browse files

Merge tag 'v0.6.1.post2' into v0.6.1.post2-dev

parents 408f663a 9ba0817f
...@@ -10,6 +10,7 @@ from pathlib import Path ...@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import openai import openai
import pytest
import requests import requests
from openai.types.completion import Completion from openai.types.completion import Completion
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -22,7 +23,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -22,7 +23,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
get_open_port, is_hip)
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage, from amdsmi import (amdsmi_get_gpu_vram_usage,
...@@ -356,12 +358,23 @@ def error_on_warning(): ...@@ -356,12 +358,23 @@ def error_on_warning():
yield yield
def get_physical_device_indices(devices):
    """Map device indices through ``CUDA_VISIBLE_DEVICES``.

    When ``CUDA_VISIBLE_DEVICES`` is unset, ``devices`` is returned unchanged.
    Otherwise the variable is read as a comma-separated list of integers and
    each index ``i`` in ``devices`` is replaced by the integer found at
    position ``i`` of that list. Indices with no matching position are
    dropped from the result.

    Empty entries (an empty variable, or a trailing or doubled comma) are
    skipped instead of crashing ``int()``. Non-integer entries still raise
    ``ValueError``.

    Args:
        devices: Iterable of indices as seen by the current process.

    Returns:
        ``devices`` itself when the variable is unset, otherwise a new list
        of the mapped integers.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return devices

    # "".split(",") yields [""], and "0,1," yields a trailing "", both of
    # which would make int() raise; ignore such blank entries.
    visible_indices = [
        int(entry) for entry in visible_devices.split(",") if entry.strip()
    ]
    index_mapping = dict(enumerate(visible_indices))
    return [index_mapping[i] for i in devices if i in index_mapping]
@_nvml() @_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int], def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int, threshold_bytes: int,
timeout_s: float = 120) -> None: timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda # Use nvml instead of pytorch to reduce measurement error from torch cuda
# context. # context.
devices = get_physical_device_indices(devices)
start_time = time.time() start_time = time.time()
while True: while True:
output: Dict[int, str] = {} output: Dict[int, str] = {}
...@@ -441,6 +454,22 @@ def fork_new_process_for_each_test( ...@@ -441,6 +454,22 @@ def fork_new_process_for_each_test(
return wrapper return wrapper
def multi_gpu_test(*, num_gpus: int):
    """
    Build a decorator for tests that need at least ``num_gpus`` GPUs.

    The decorated test is wrapped with ``fork_new_process_for_each_test``,
    marked to be skipped when ``cuda_device_count_stateless()`` reports fewer
    than ``num_gpus`` devices, and tagged with the
    ``distributed_{num_gpus}_gpus`` pytest marker.
    """
    # Both markers are built once, when the decorator factory is called.
    gpu_marker = getattr(pytest.mark, f"distributed_{num_gpus}_gpus")
    too_few_gpus = cuda_device_count_stateless() < num_gpus
    skip_marker = pytest.mark.skipif(
        too_few_gpus,
        reason=f"Need at least {num_gpus} GPUs to run the test.",
    )

    def decorate(f: Callable[_P, None]) -> Callable[_P, None]:
        # Innermost to outermost: fork, then skip marker, then selector.
        forked = fork_new_process_for_each_test(f)
        skippable = skip_marker(forked)
        return gpu_marker(skippable)

    return decorate
async def completions_with_server_args( async def completions_with_server_args(
prompts: List[str], prompts: List[str],
model_name: str, model_name: str,
......
...@@ -251,16 +251,36 @@ def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor, ...@@ -251,16 +251,36 @@ def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def advance_step(num_seqs: int, num_queries: int, block_size: int, def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, input_tokens: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor, sampled_token_ids: torch.Tensor,
slot_mapping: torch.Tensor, input_positions: torch.Tensor,
block_tables: torch.Tensor) -> None: seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner""" """Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step(num_seqs, num_queries, block_size, return torch.ops._C.advance_step_flashattn(num_seqs, num_queries,
input_tokens, sampled_token_ids, block_size, input_tokens,
input_positions, seq_lens, slot_mapping, sampled_token_ids,
block_tables) input_positions, seq_lens,
slot_mapping, block_tables)
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
                            input_positions: torch.Tensor,
                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                            block_tables: torch.Tensor,
                            paged_kv_indices: torch.Tensor,
                            paged_kv_indptr: torch.Tensor,
                            paged_kv_last_page_len: torch.Tensor,
                            block_table_bound: torch.Tensor) -> None:
    """Call the ``_C.advance_step_flashinfer`` custom op.

    Every argument is forwarded positionally and unchanged, and the op's
    result is returned as-is.
    NOTE(review): presumably the FlashInfer counterpart of
    ``advance_step_flashattn`` for the multi-step runner -- confirm against
    the kernel implementation.
    """
    step_op = torch.ops._C.advance_step_flashinfer
    return step_op(num_seqs, num_queries, block_size, input_tokens,
                   sampled_token_ids, input_positions, seq_lens, slot_mapping,
                   block_tables, paged_kv_indices, paged_kv_indptr,
                   paged_kv_last_page_len, block_table_bound)
# trans_w16 # trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor, def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
......
...@@ -83,7 +83,9 @@ class AttentionBackend(ABC): ...@@ -83,7 +83,9 @@ class AttentionBackend(ABC):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
def advance_step(self, num_seqs: int, num_queries: int): def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -122,6 +122,40 @@ def _( ...@@ -122,6 +122,40 @@ def _(
return torch.empty_like(decode_query) return torch.empty_like(decode_query)
@torch.library.custom_op("vllm::reshape_and_cache_flash",
                         mutates_args=["kv_cache"])
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Custom-op wrapper around ``_C_cache_ops.reshape_and_cache_flash``.

    Inductor cannot handle in-place operations on views, see
    https://github.com/pytorch/pytorch/issues/131192
    and https://github.com/pytorch/pytorch/issues/130174
    As a workaround, ``kv_cache`` is indexed into its two halves inside this
    op so the view operation stays hidden from the inductor.
    """
    # kv_cache[0] / kv_cache[1] are passed as the key and value caches.
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    return torch.ops._C_cache_ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
        k_scale, v_scale)
@reshape_and_cache_flash.register_fake  # type: ignore
def _(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
) -> None:
    """Fake implementation of ``reshape_and_cache_flash``; it does nothing."""
    return None
class FlashAttentionBackend(AttentionBackend): class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
...@@ -346,15 +380,15 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -346,15 +380,15 @@ class FlashAttentionMetadata(AttentionMetadata):
self.seq_lens[i] += 1 self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens) self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step(num_seqs=num_seqs, ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries, num_queries=num_queries,
block_size=block_size, block_size=block_size,
input_tokens=model_input.input_tokens, input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions, input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor, seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping,
block_tables=self.block_tables) block_tables=self.block_tables)
class FlashAttentionMetadataBuilder( class FlashAttentionMetadataBuilder(
...@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -653,11 +687,10 @@ class FlashAttentionImpl(AttentionImpl):
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are # If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run. # not cached. This happens during the initial memory profiling run.
ops.reshape_and_cache_flash( torch.ops.vllm.reshape_and_cache_flash(
key, key,
value, value,
key_cache, kv_cache,
value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
k_scale, k_scale,
...@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -669,7 +702,6 @@ class FlashAttentionImpl(AttentionImpl):
assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached. # Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:] decode_query = query[num_prefill_tokens:]
# QKV for prefill. # QKV for prefill.
...@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -680,6 +712,9 @@ class FlashAttentionImpl(AttentionImpl):
assert query.shape[0] == num_prefill_tokens assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens assert decode_query.shape[0] == num_decode_tokens
prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if (kv_cache is None or prefill_meta.block_tables is None if (kv_cache is None or prefill_meta.block_tables is None
...@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -687,7 +722,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
out = torch.ops.vllm.flash_attn_varlen_func( prefill_output = torch.ops.vllm.flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -701,42 +736,44 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -701,42 +736,44 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
output[: prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa
num_prefill_tokens] = torch.ops.vllm.flash_attn_varlen_func( # noqa q=query,
q=query, k=key_cache,
k=key_cache, v=value_cache,
v=value_cache, cu_seqlens_q=prefill_meta.query_start_loc,
cu_seqlens_q=prefill_meta.query_start_loc, max_seqlen_q=prefill_meta.max_query_len,
max_seqlen_q=prefill_meta.max_query_len, cu_seqlens_k=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc, max_seqlen_k=max_seq_len,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap,
)
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output[
num_prefill_tokens:] = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=self.logits_soft_cap, softcap=self.logits_soft_cap,
).squeeze(1) )
# Reshape the output tensor. if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query.unsqueeze(1),
key_cache,
value_cache,
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
).squeeze(1)
if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, ...@@ -30,7 +30,8 @@ from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad) make_tensor_with_pad)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc: Optional[torch.Tensor] = None query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None
# used for GPU in-place advance_step
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None
# An example for paged_kv_indices, paged_kv_indptr: # An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8] # request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7] # request 2, page indices [1, 6, 7]
...@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -318,6 +323,8 @@ class FlashInferMetadata(AttentionMetadata):
assert self.paged_kv_indices is not None assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
batch_size = self.query_start_loc.shape[0] - 1 batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0 assert batch_size >= 0
# We will use flash attention for profiling to # We will use flash attention for profiling to
...@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -327,6 +334,8 @@ class FlashInferMetadata(AttentionMetadata):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device) self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward() self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward( self.prefill_wrapper.begin_forward(
...@@ -335,14 +344,18 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -335,14 +344,18 @@ class FlashInferMetadata(AttentionMetadata):
self.num_qo_heads, self.num_kv_heads, self.head_dim, self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size) self.page_size)
else: else:
if not self.use_cuda_graph: assert self.paged_kv_indices is not None
assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None
assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None
assert self.paged_kv_last_page_len is not None self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device)
self.device) # handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
assert self.decode_wrapper is not None assert self.decode_wrapper is not None
self.decode_wrapper.end_forward() self.decode_wrapper.end_forward()
...@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata): ...@@ -391,6 +404,48 @@ class FlashInferMetadata(AttentionMetadata):
return self return self
def advance_step(
    self,
    model_input: "ModelInputForGPUWithSamplingMetadata",
    sampled_token_ids: Optional[torch.Tensor],
    block_size: int,
    num_seqs: int,
    num_queries: int,
):
    """
    Advance this metadata by one decode step, updating it in place.

    The sampled token ids are copied into the first ``num_queries`` entries
    of ``model_input.input_tokens``; the remaining tensors are then updated
    by ``ops.advance_step_flashinfer``.
    """
    assert num_seqs > 0
    assert num_queries > 0
    assert model_input.attn_metadata is not None
    assert sampled_token_ids is not None

    # With cudagraph, num_seqs is padded up to the next captured batch size
    # while num_queries counts the real requests in the batch. In
    # --enforce-eager mode the two are equal.
    is_padded = num_seqs != num_queries
    if is_padded:
        assert num_seqs > num_queries
        assert self.use_cuda_graph

    tokens = model_input.input_tokens
    tokens[:num_queries] = sampled_token_ids.flatten()

    # Update GPU tensors. The freshly written input_tokens buffer is also
    # passed as sampled_token_ids.
    ops.advance_step_flashinfer(
        num_seqs=num_seqs,
        num_queries=num_queries,
        block_size=block_size,
        input_tokens=tokens,
        sampled_token_ids=tokens,
        input_positions=model_input.input_positions,
        seq_lens=self.seq_lens_tensor,
        slot_mapping=self.slot_mapping,
        block_tables=self.block_tables,
        paged_kv_indices=self.paged_kv_indices,
        paged_kv_indptr=self.paged_kv_indptr,
        paged_kv_last_page_len=self.paged_kv_last_page_len,
        block_table_bound=self.block_table_bound)
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -428,7 +483,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_indptr: List[int] = [0] self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request # paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = [] self.paged_kv_last_page_len: List[int] = []
self.total_blocks = 0
self.is_profile_run: bool = False self.is_profile_run: bool = False
def _add_seq_group( def _add_seq_group(
...@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -499,6 +554,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# block_table_bound is 1 with 1 valid block. # block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16, # If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block. # block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \ block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \ if seq_len % self.block_size != 0 \
else seq_len // self.block_size else seq_len // self.block_size
...@@ -541,9 +597,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -541,9 +597,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size] input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables): for i, block_table in enumerate(self.block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]
block_tables = torch.from_numpy(input_block_tables).to( block_tables = torch.from_numpy(input_block_tables).to(
device, non_blocking=True) device, non_blocking=True)
...@@ -583,6 +649,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -583,6 +649,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
out=query_start_loc[1:]) out=query_start_loc[1:])
if len(self.paged_kv_indptr) > 0: if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu", device="cpu",
dtype=torch.int) dtype=torch.int)
...@@ -591,10 +661,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -591,10 +661,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
dtype=torch.int) dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor( paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int) self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else: else:
paged_kv_indices_tensor = None paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None
if self.runner.kv_cache_dtype.startswith("fp8"): if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
...@@ -613,6 +688,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -613,6 +688,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor, paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor,
block_table_bound=block_table_bound_tensor,
seq_lens_tensor=seq_lens_tensor,
num_qo_heads=self.runner.model_config.get_num_attention_heads( num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config), self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads( num_kv_heads=self.runner.model_config.get_num_kv_heads(
......
...@@ -869,6 +869,13 @@ class ParallelConfig: ...@@ -869,6 +869,13 @@ class ParallelConfig:
f"distributed executor backend " f"distributed executor backend "
f"'{self.distributed_executor_backend}'.") f"'{self.distributed_executor_backend}'.")
if current_platform.is_tpu() and self.world_size > 1:
if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray"
if self.distributed_executor_backend != "ray":
raise ValueError(
"TPU backend only supports Ray for distributed inference.")
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
...@@ -876,7 +883,7 @@ class ParallelConfig: ...@@ -876,7 +883,7 @@ class ParallelConfig:
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend = "mp" backend = "mp"
ray_found = ray_utils.ray_is_available() ray_found = ray_utils.ray_is_available()
if (torch.cuda.is_available() if (current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size): and cuda_device_count_stateless() < self.world_size):
if not ray_found: if not ray_found:
raise ValueError("Unable to load Ray which is " raise ValueError("Unable to load Ray which is "
......
...@@ -843,6 +843,13 @@ class EngineArgs: ...@@ -843,6 +843,13 @@ class EngineArgs:
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
model_config = self.create_model_config() model_config = self.create_model_config()
if model_config.is_multimodal_model:
if self.enable_prefix_caching:
logger.warning(
"--enable-prefix-caching is currently not "
"supported for multimodal models and has been disabled.")
self.enable_prefix_caching = False
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size if self.device != "neuron" else block_size=self.block_size if self.device != "neuron" else
self.max_model_len, # neuron needs block_size = max_model_len self.max_model_len, # neuron needs block_size = max_model_len
...@@ -874,7 +881,10 @@ class EngineArgs: ...@@ -874,7 +881,10 @@ class EngineArgs:
# If not explicitly set, enable chunked prefill by default for # If not explicitly set, enable chunked prefill by default for
# long context (> 32K) models. This is to avoid OOM errors in the # long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase. # initial memory profiling phase.
if use_long_context:
# Chunked prefill is currently disabled for multimodal models by
# default.
if use_long_context and not model_config.is_multimodal_model:
is_gpu = device_config.device_type == "cuda" is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window() use_sliding_window = (model_config.get_sliding_window()
is not None) is not None)
...@@ -1035,7 +1045,6 @@ class EngineArgs: ...@@ -1035,7 +1045,6 @@ class EngineArgs:
@dataclass @dataclass
class AsyncEngineArgs(EngineArgs): class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine.""" """Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False disable_log_requests: bool = False
@staticmethod @staticmethod
...@@ -1043,16 +1052,6 @@ class AsyncEngineArgs(EngineArgs): ...@@ -1043,16 +1052,6 @@ class AsyncEngineArgs(EngineArgs):
async_args_only: bool = False) -> FlexibleArgumentParser: async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only: if not async_args_only:
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
action='store_true',
help='Use Ray to start the LLM engine in a '
'separate process as the server process.'
'(DEPRECATED. This argument is deprecated '
'and will be removed in a future update. '
'Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to force '
'use it. See '
'https://github.com/vllm-project/vllm/issues/7045.'
')')
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict ...@@ -12,7 +12,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
SamplingParams)
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -316,6 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
...@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -559,6 +562,8 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
) )
@model_validator(mode="before") @model_validator(mode="before")
......
...@@ -195,7 +195,6 @@ async def main(args): ...@@ -195,7 +195,6 @@ async def main(args):
engine = AsyncLLMEngine.from_engine_args( engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config() model_config = await engine.get_model_config()
if args.disable_log_requests: if args.disable_log_requests:
......
This diff is collapsed.
...@@ -33,7 +33,6 @@ class Hermes2ProToolParser(ToolParser): ...@@ -33,7 +33,6 @@ class Hermes2ProToolParser(ToolParser):
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: List[Dict] = [] self.prev_tool_call_arr: List[Dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
self.current_tool_name_sent = False
self.streamed_args_for_tool: List[str] = [ self.streamed_args_for_tool: List[str] = [
] # map what has been streamed for each tool so far to a list ] # map what has been streamed for each tool so far to a list
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment