Commit 53076d70 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.2' into v0.8.2-ori

parents 322a0be6 9c5c81b0
...@@ -3,8 +3,8 @@ import pytest ...@@ -3,8 +3,8 @@ import pytest
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput) SchedulerOutput)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
......
...@@ -4,12 +4,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, ...@@ -4,12 +4,16 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionMetadata,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionState, AttentionType) AttentionState, AttentionType)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
__all__ = [ __all__ = [
"Attention", "AttentionBackend", "AttentionMetadata", "AttentionType", "Attention",
"AttentionMetadataBuilder", "Attention", "AttentionState", "AttentionBackend",
"get_attn_backend", "get_flash_attn_version" "AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
] ]
...@@ -232,6 +232,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]): ...@@ -232,6 +232,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
class AttentionLayer(Protocol): class AttentionLayer(Protocol):
_q_scale: torch.Tensor
_k_scale: torch.Tensor _k_scale: torch.Tensor
_v_scale: torch.Tensor _v_scale: torch.Tensor
_k_scale_float: float _k_scale_float: float
......
...@@ -19,10 +19,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -19,10 +19,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
# yapf: enable # yapf: enable
from vllm.attention.backends.utils import ( from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_flash_attn_version, compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty)
is_block_tables_empty) from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
...@@ -630,9 +630,12 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -630,9 +630,12 @@ class FlashAttentionImpl(AttentionImpl):
self.sliding_window = ((sliding_window - 1, self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1)) 0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype): self.vllm_flash_attn_version = get_flash_attn_version(
requires_alibi=self.alibi_slopes is not None)
if (is_quantized_kv_cache(self.kv_cache_dtype)
and self.vllm_flash_attn_version != 3):
raise NotImplementedError( raise NotImplementedError(
"FlashAttention with FP8 KV cache not yet supported") "Only FlashAttention3 supports FP8 KV cache")
if logits_soft_cap is None: if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap. # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0 logits_soft_cap = 0
...@@ -647,7 +650,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -647,7 +650,6 @@ class FlashAttentionImpl(AttentionImpl):
f"Head size {head_size} is not supported by FlashAttention. " f"Head size {head_size} is not supported by FlashAttention. "
f"Supported head sizes are: {support_head_sizes}.") f"Supported head sizes are: {support_head_sizes}.")
self.attn_type = attn_type self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
def forward( def forward(
self, self,
...@@ -671,13 +673,19 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -671,13 +673,19 @@ class FlashAttentionImpl(AttentionImpl):
for profiling run. for profiling run.
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
NOTE: It in-place updates the output tensor. NOTE: It in-place updates the output tensor.
NOTE: For FP8 quantization, flash-attn expects the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values.
""" """
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with "
"base dtype bfloat16")
attn_type = self.attn_type attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)): and (not attn_metadata.is_all_encoder_attn_metadata_set)):
...@@ -694,6 +702,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -694,6 +702,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size = self.sliding_window window_size = self.sliding_window
alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
logits_soft_cap: Optional[float] = self.logits_soft_cap logits_soft_cap: Optional[float] = self.logits_soft_cap
fp8_attention = kv_cache_dtype.startswith("fp8")
if kv_cache.numel() > 0: if kv_cache.numel() > 0:
key_cache = kv_cache[0] key_cache = kv_cache[0]
...@@ -729,6 +738,19 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -729,6 +738,19 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale, layer._v_scale,
) )
if fp8_attention:
kv_cache = kv_cache.view(torch.float8_e4m3fn)
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
if fp8_attention:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
(num_prefill_query_tokens, num_prefill_kv_tokens, (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \ num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
...@@ -753,6 +775,23 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -753,6 +775,23 @@ class FlashAttentionImpl(AttentionImpl):
key = key[:num_prefill_kv_tokens] key = key[:num_prefill_kv_tokens]
value = value[:num_prefill_kv_tokens] value = value[:num_prefill_kv_tokens]
if fp8_attention:
num_kv_tokens, num_kv_heads, head_size = key.shape
key, _ = ops.scaled_fp8_quant(
key.reshape((num_kv_tokens,
num_kv_heads * head_size)).contiguous(),
layer._k_scale)
key = key.reshape((num_kv_tokens, num_kv_heads, head_size))
value, _ = ops.scaled_fp8_quant(
value.reshape((num_kv_tokens,
num_kv_heads * head_size)).contiguous(),
layer._v_scale)
value = value.reshape(
(num_kv_tokens, num_kv_heads, head_size))
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
flash_attn_varlen_func( flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
...@@ -768,13 +807,19 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -768,13 +807,19 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=prefill_output, out=prefill_output,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
) )
else: else:
# prefix-enabled attention # prefix-enabled attention
assert attn_type == AttentionType.DECODER, ( assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching") "Only decoder-only models support prefix caching")
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
assert prefill_meta.query_start_loc is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func( # noqa flash_attn_varlen_func( # noqa
q=query, q=query,
k=key_cache, k=key_cache,
...@@ -791,6 +836,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -791,6 +836,9 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=prefill_output, out=prefill_output,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
...@@ -804,6 +852,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -804,6 +852,9 @@ class FlashAttentionImpl(AttentionImpl):
assert attn_type == AttentionType.DECODER, ( assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support max_decode_query_len > 1" "Only decoder-only models support max_decode_query_len > 1"
) )
assert decode_meta.query_start_loc is not None
descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
key.shape[1])
flash_attn_varlen_func( flash_attn_varlen_func(
q=decode_query, q=decode_query,
k=key_cache, k=key_cache,
...@@ -820,6 +871,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -820,6 +871,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
out=decode_output, out=decode_output,
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
) )
else: else:
# Use flash_attn_with_kvcache for normal decoding. # Use flash_attn_with_kvcache for normal decoding.
...@@ -828,6 +882,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -828,6 +882,7 @@ class FlashAttentionImpl(AttentionImpl):
_, _,
block_tables_arg, block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type) ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
flash_attn_with_kvcache( flash_attn_with_kvcache(
q=decode_query.unsqueeze(1), q=decode_query.unsqueeze(1),
k_cache=key_cache, k_cache=key_cache,
...@@ -841,6 +896,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -841,6 +896,9 @@ class FlashAttentionImpl(AttentionImpl):
softcap=logits_soft_cap, softcap=logits_soft_cap,
out=decode_output.unsqueeze(1), out=decode_output.unsqueeze(1),
fa_version=self.vllm_flash_attn_version, fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
) )
return output return output
......
...@@ -203,9 +203,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, ...@@ -203,9 +203,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionState, MLAAttentionImpl) AttentionState, MLAAttentionImpl)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
get_flash_attn_version,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
import openvino as ov
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.multimodal import MultiModalPlaceholderMap
def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,
                     src_offset: int, dst_offset: int) -> None:
    """Copy a single cache block from ``src_tensor`` into ``dst_tensor``.

    A region-of-interest (ROI) view that spans exactly one index along
    dimension 0 is created for each tensor; the source view is then copied
    onto the destination view with ``copy_to``.

    Args:
        src_tensor: Tensor the block is read from.
        dst_tensor: Tensor the block is written to.
        src_offset: Index of the block along dimension 0 of ``src_tensor``.
        dst_offset: Index of the block along dimension 0 of ``dst_tensor``.
    """

    def _block_view(tensor: ov.Tensor, block_number: int) -> ov.Tensor:
        # The ROI starts at the origin and ends at the tensor's full shape,
        # then is narrowed to [block_number, block_number + 1) on dim 0.
        # NOTE(review): the begin coordinate is hard-coded with four
        # entries, so tensors are presumably 4-D -- confirm with callers.
        lower = ov.runtime.Coordinate([0, 0, 0, 0])
        upper = ov.runtime.Coordinate(tensor.get_shape())
        lower[0] = block_number
        upper[0] = block_number + 1

        # Anything that is not a plain ov.Tensor is wrapped as a
        # RemoteTensor instead.
        if isinstance(tensor, ov.Tensor):
            view_cls = ov.Tensor
        else:
            view_cls = ov.RemoteTensor
        return view_cls(tensor, lower, upper)

    source_view = _block_view(src_tensor, src_offset)
    target_view = _block_view(dst_tensor, dst_offset)
    source_view.copy_to(target_view)
class OpenVINOAttentionBackend(AttentionBackend):
    """Attention backend description for OpenVINO.

    Only naming, metadata construction, KV-cache shape and block copy
    helpers are provided here; ``get_impl_cls`` and ``make_metadata``
    deliberately raise ``NotImplementedError``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "OPENVINO"

    @staticmethod
    def get_impl_cls():
        """Not available for this backend; always raises."""
        # OpenVINO implements PagedAttention as part of the Optimum
        # exported model
        raise NotImplementedError

    @staticmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Not available; use ``make_openvino_metadata`` instead."""
        raise NotImplementedError

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class used with this backend."""
        return CommonAttentionState

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        """Build an ``OpenVINOAttentionMetadata`` from the given arguments."""
        return OpenVINOAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape for the given sizes.

        The result is
        ``(2, num_blocks, num_kv_heads, block_size, head_size)``.
        """
        shape = (2, num_blocks, num_kv_heads, block_size, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_tensor: ov.Tensor,
        dst_tensor: ov.Tensor,
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Copy blocks from ``src_tensor`` to ``dst_tensor``.

        ``src_to_dists`` holds ``(source block, destination block)`` pairs.
        """
        for src_block, dst_block in src_to_dists:
            copy_cache_block(src_tensor, dst_tensor, src_block, dst_block)

    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
        src_to_dists: List[Tuple[int, int]],
    ) -> None:
        """Duplicate blocks in place inside every key and value cache.

        For each ``(source block, destination block)`` pair, the block is
        copied within the key cache and then within the value cache of
        every entry of ``kv_caches``.
        """
        for src_block, dst_block in src_to_dists:
            for key_cache, value_cache in kv_caches:
                # Key first, then value -- source and destination tensor
                # are the same object, only the block index differs.
                for cache in (key_cache, value_cache):
                    copy_cache_block(cache, cache, src_block, dst_block)
@dataclass
class OpenVINOAttentionMetadata:
    """Per-batch attention inputs used with OpenVINOAttentionBackend.

    Terminology used in the field comments:
    - batch_size_in_sequences: total number of sequences to execute
    - prompt_lens: number of scheduled tokens for each sequence
    - batch_size_in_tokens: sum(prompt_lens)
    - max_context_len: max(context_lens)
    - max_num_blocks: div_up(max_context_len, BLOCK_SIZE)
    - num_blocks: total number of blocks in block_indices
    """

    # Size of the past KV cache for every sequence in the batch.
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Start indices of the input / speculative tokens of each sequence
    # within the batch.
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Block tables of all sequences in the batch: indices along the 0th
    # dimension of the key_cache and value_cache inputs.
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # For the i-th sequence, the index into block_indices of the first
    # block that belongs to that sequence.
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Maximum context length.
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor

    # Index maps relating multi-modal embeddings to their placeholders.
    #
    # N.B. These are not really attention-related and do not belong on
    # this type -- it is a temporary way of making them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Turns KV scales calculation on or off, so that it can stay disabled
    # until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool
...@@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union ...@@ -8,13 +8,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import numpy as np import numpy as np
import torch import torch
from vllm import envs
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState) AttentionState)
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -585,35 +583,3 @@ def get_num_prefill_decode_query_kv_tokens( ...@@ -585,35 +583,3 @@ def get_num_prefill_decode_query_kv_tokens(
return (num_prefill_query_tokens, num_prefill_kv_tokens, return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) num_decode_query_tokens)
def get_flash_attn_version():
    """Select the FlashAttention major version to use.

    The default is FA3 on devices whose capability major version is 9
    (Hopper) when FA3 is supported, and FA2 otherwise. A non-``None``
    ``envs.VLLM_FLASH_ATTN_VERSION`` overrides the default, except that a
    request for FA3 on capability major version 10 (Blackwell) is
    downgraded to FA2 with a warning.

    Returns:
        ``2`` or ``3``, or ``None`` when the vllm flash-attn interface
        cannot be imported, the override is neither 2 nor 3, or the chosen
        version is reported as unsupported.
    """
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)

        # Queried once; both platform checks below need the major version.
        capability_major = current_platform.get_device_capability()[0]

        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        # use FA3 as default for both
        if capability_major == 9 and is_fa_version_supported(3):
            fa_version = 3
        else:
            fa_version = 2

        requested = envs.VLLM_FLASH_ATTN_VERSION
        if requested is not None:
            # Explicit check rather than `assert`, so an invalid override
            # is still rejected when Python runs with -O.
            if requested not in (2, 3):
                return None
            fa_version = requested
            if capability_major == 10 and requested == 3:
                # One message string: a second positional string would be
                # taken as a %-format argument and break the log record.
                logger.warning(
                    "Cannot use FA version 3 on Blackwell platform, "
                    "defaulting to FA version 2.")
                fa_version = 2

        if not is_fa_version_supported(fa_version):
            logger.error("Cannot use FA version %d is not supported due to %s",
                         fa_version, fa_version_unsupported_reason(fa_version))
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError stays in the tuple in case the imported helpers
        # themselves assert; callers treat None as "no usable version".
        return None
...@@ -84,6 +84,9 @@ class Attention(nn.Module): ...@@ -84,6 +84,9 @@ class Attention(nn.Module):
self.calculate_kv_scales = calculate_kv_scales self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention # We also keep the float32 versions of k/v_scale for attention
# backends that don't support tensors (Flashinfer) # backends that don't support tensors (Flashinfer)
...@@ -153,6 +156,7 @@ class Attention(nn.Module): ...@@ -153,6 +156,7 @@ class Attention(nn.Module):
).parallel_config.pipeline_parallel_size) ).parallel_config.pipeline_parallel_size)
] ]
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
...@@ -178,7 +182,7 @@ class Attention(nn.Module): ...@@ -178,7 +182,7 @@ class Attention(nn.Module):
if self.calculate_kv_scales: if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation: if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value) self.calc_kv_scales(query, key, value)
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
...@@ -225,7 +229,8 @@ class Attention(nn.Module): ...@@ -225,7 +229,8 @@ class Attention(nn.Module):
return torch.ops.vllm.unified_attention( return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name) query, key, value, self.layer_name)
def calc_kv_scales(self, key, value): def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._k_scale_float = self._k_scale.item() self._k_scale_float = self._k_scale.item()
...@@ -276,8 +281,7 @@ class MultiHeadAttention(nn.Module): ...@@ -276,8 +281,7 @@ class MultiHeadAttention(nn.Module):
backend = _Backend.XFORMERS backend = _Backend.XFORMERS
self.attn_backend = backend if backend in { self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
_Backend.XFORMERS,
} else _Backend.TORCH_SDPA } else _Backend.TORCH_SDPA
def forward( def forward(
...@@ -315,6 +319,13 @@ class MultiHeadAttention(nn.Module): ...@@ -315,6 +319,13 @@ class MultiHeadAttention(nn.Module):
value, value,
scale=self.scale) scale=self.scale)
out = out.transpose(1, 2) out = out.transpose(1, 2)
elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
query, key, value = (x.transpose(1, 2)
for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out.reshape(bsz, q_len, -1) return out.reshape(bsz, q_len, -1)
......
...@@ -357,6 +357,11 @@ class VllmBackend: ...@@ -357,6 +357,11 @@ class VllmBackend:
# graph. # graph.
factors = [] factors = []
# 0. factors come from the env, for example, the value of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)
# 1. factors come from the vllm_config (it mainly summarizes how the # 1. factors come from the vllm_config (it mainly summarizes how the
# model is created) # model is created)
config_hash = vllm_config.compute_hash() config_hash = vllm_config.compute_hash()
...@@ -399,6 +404,7 @@ class VllmBackend: ...@@ -399,6 +404,7 @@ class VllmBackend:
rank = vllm_config.parallel_config.rank rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE
......
...@@ -246,6 +246,7 @@ class ModelConfig: ...@@ -246,6 +246,7 @@ class ModelConfig:
max_seq_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20, max_logprobs: int = 20,
disable_sliding_window: bool = False, disable_sliding_window: bool = False,
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None, served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
...@@ -322,6 +323,7 @@ class ModelConfig: ...@@ -322,6 +323,7 @@ class ModelConfig:
self.max_seq_len_to_capture = max_seq_len_to_capture self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window self.disable_sliding_window = disable_sliding_window
self.disable_cascade_attn = disable_cascade_attn
self.skip_tokenizer_init = skip_tokenizer_init self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode self.enable_sleep_mode = enable_sleep_mode
...@@ -670,14 +672,6 @@ class ModelConfig: ...@@ -670,14 +672,6 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
MODEL_NOT_SUPPORT_CUDA_GRAPH = ['mllama']
if (self.hf_config.model_type in MODEL_NOT_SUPPORT_CUDA_GRAPH
and not self.enforce_eager):
logger.warning(
"CUDA graph is not supported for %s yet, fallback to the eager "
"mode.", self.hf_config.model_type)
self.enforce_eager = True
def _verify_bnb_config(self) -> None: def _verify_bnb_config(self) -> None:
""" """
The current version of bitsandbytes (0.44.0) with 8-bit models does not The current version of bitsandbytes (0.44.0) with 8-bit models does not
...@@ -1029,6 +1023,13 @@ class ModelConfig: ...@@ -1029,6 +1023,13 @@ class ModelConfig:
"max_new_tokens") "max_new_tokens")
else: else:
diff_sampling_param = {} diff_sampling_param = {}
if diff_sampling_param:
logger.warning_once(
"Default sampling parameters have been overridden by the "
"model's Hugging Face generation config recommended from the "
"model creator. If this is not intended, please relaunch "
"vLLM instance with `--generation-config vllm`.")
return diff_sampling_param return diff_sampling_param
@property @property
...@@ -1300,6 +1301,12 @@ class LoadConfig: ...@@ -1300,6 +1301,12 @@ class LoadConfig:
"tensorizer" will use CoreWeave's tensorizer library for "tensorizer" will use CoreWeave's tensorizer library for
fast weight loading. fast weight loading.
"bitsandbytes" will load nf4 type weights. "bitsandbytes" will load nf4 type weights.
"sharded_state" will load weights from pre-sharded checkpoint files,
supporting efficient loading of tensor-parallel models.
"gguf" will load weights from GGUF format files.
"mistral" will load weights from consolidated safetensors files used
by Mistral models.
"runai_streamer" will load weights from RunAI streamer format files.
model_loader_extra_config: The extra config for the model loader. model_loader_extra_config: The extra config for the model loader.
ignore_patterns: The list of patterns to ignore when loading the model. ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's Default to "original/**/*" to avoid repeated loading of llama's
...@@ -1473,7 +1480,7 @@ class ParallelConfig: ...@@ -1473,7 +1480,7 @@ class ParallelConfig:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.") logger.info("Disabling V1 multiprocessing for external launcher.")
ray_only_devices = ["tpu"] ray_only_devices: list[str] = []
from vllm.platforms import current_platform from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices if (current_platform.device_type in ray_only_devices
and self.world_size > 1): and self.world_size > 1):
...@@ -1801,7 +1808,7 @@ class DeviceConfig: ...@@ -1801,7 +1808,7 @@ class DeviceConfig:
self.device_type = device self.device_type = device
# Some device types require processing inputs on CPU # Some device types require processing inputs on CPU
if self.device_type in ["neuron", "openvino"]: if self.device_type in ["neuron"]:
self.device = torch.device("cpu") self.device = torch.device("cpu")
elif self.device_type in ["tpu"]: elif self.device_type in ["tpu"]:
self.device = None self.device = None
...@@ -1810,12 +1817,139 @@ class DeviceConfig: ...@@ -1810,12 +1817,139 @@ class DeviceConfig:
self.device = torch.device(self.device_type) self.device = torch.device(self.device_type)
@dataclass
class SpeculativeConfig: class SpeculativeConfig:
"""Configuration for speculative decoding. """
Configuration for speculative decoding.
Configurable parameters include:
- General Speculative Decoding Control:
- num_speculative_tokens (int): The number of speculative
tokens, if provided. It will default to the number in the draft
model config if present, otherwise, it is required.
- model (Optional[str]): The name of the draft model, eagle head,
or additional weights, if provided.
- method (Optional[str]): The name of the speculative method to use.
If users provide and set the `model` param, the speculative method
type will be detected automatically if possible; if the `model` param
is not provided, the method name must be provided.
- Possible values:
- ngram
Related additional configuration:
- prompt_lookup_max (Optional[int]):
Maximum size of ngram token window when using Ngram
proposer, required when method is set to ngram.
- prompt_lookup_min (Optional[int]):
Minimum size of ngram token window when using Ngram
proposer, if provided. Defaults to 1.
- eagle
- medusa
- mlp_speculator
- draft_model
- acceptance_method (str): The method to use for accepting draft
tokens. This can take two possible values: 'rejection_sampler' and
'typical_acceptance_sampler' for RejectionSampler and
TypicalAcceptanceSampler respectively. If not specified, it
defaults to 'rejection_sampler'.
- Possible values:
- rejection_sampler
- typical_acceptance_sampler
Related additional configuration:
- posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the
posterior probability of a token in the target model
for it to be accepted. This threshold is used only
when we use the TypicalAcceptanceSampler for token
acceptance.
- posterior_alpha (Optional[float]):
Scaling factor for entropy-based threshold, applied
when using TypicalAcceptanceSampler.
- draft_tensor_parallel_size (Optional[int]): The degree of the tensor
parallelism for the draft model. Can only be 1 or the same as the
target model's tensor parallel size.
- disable_logprobs (bool): If set to True, token log probabilities are
not returned during speculative decoding. If set to False, token
log probabilities are returned according to the log probability
settings in SamplingParams. If not specified, it defaults to True.
- Draft Model Configuration:
- quantization (Optional[str]): Quantization method that was used to
quantize the draft model weights. If None, we assume the
model weights are not quantized. Note that it only takes effect
when using the draft model-based speculative method.
- max_model_len (Optional[int]): The maximum model length of the
draft model. Used when testing the ability to skip
speculation for some sequences.
- revision: The specific model version to use for the draft model. It
can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version.
- code_revision: The specific revision to use for the draft model code
on Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
The configuration is currently specialized to draft-model speculative - Advanced Control:
decoding with top-1 proposals. - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
batch expansion for scoring proposals. If not specified, it
defaults to False.
- disable_by_batch_size (Optional[int]): Disable speculative decoding
for new incoming requests when the number of enqueued requests is
larger than this value, if provided.
Although the parameters above are structured hierarchically, there is no
need to nest them during configuration.
Non-configurable internal parameters include:
- Model Configuration:
- target_model_config (ModelConfig): The configuration of the target
model.
- draft_model_config (ModelConfig): The configuration of the draft
model, initialized internally.
- Parallelism Configuration:
- target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
- draft_parallel_config (ParallelConfig): The parallel configuration
for the draft model, initialized internally.
- Execution Control:
- enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with speculative decode.
- disable_log_stats (bool): Whether to disable the periodic printing of
stage times in speculative decoding.
""" """
# speculative configs from cli args
num_speculative_tokens: int = field(default=None,
init=True) # type: ignore
method: Optional[str] = None
acceptance_method: str = "rejection_sampler"
draft_tensor_parallel_size: Optional[int] = None
disable_logprobs: bool = True
model: Optional[str] = None
quantization: Optional[str] = None
max_model_len: Optional[int] = None
revision: Optional[str] = None
code_revision: Optional[str] = None
disable_mqa_scorer: bool = False
disable_by_batch_size: Optional[int] = None
prompt_lookup_max: Optional[int] = None
prompt_lookup_min: Optional[int] = None
posterior_threshold: Optional[float] = None
posterior_alpha: Optional[float] = None
# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
target_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
enable_chunked_prefill: bool = field(default=None,
init=True) # type: ignore
disable_log_stats: bool = field(default=None, init=True) # type: ignore
# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
draft_parallel_config: ParallelConfig = field(default=None,
init=True) # type: ignore
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
...@@ -1835,6 +1969,11 @@ class SpeculativeConfig: ...@@ -1835,6 +1969,11 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest() hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str return hash_str
@classmethod
def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
@staticmethod @staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3": if hf_config.model_type == "deepseek_v3":
...@@ -1847,230 +1986,172 @@ class SpeculativeConfig: ...@@ -1847,230 +1986,172 @@ class SpeculativeConfig:
}) })
return hf_config return hf_config
@staticmethod def __post_init__(self):
def maybe_create_spec_config(
target_model_config: ModelConfig,
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
This function attempts to create a SpeculativeConfig object based on the
provided parameters. If the necessary conditions are met, it returns an
instance of SpeculativeConfig. Otherwise, it returns None.
Args:
target_model_config (ModelConfig): The configuration of the target
model.
target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_model_quantization (Optional[str]): Quantization method
that was used to quantize the speculative model weights. If
None, we assume the model weights are not quantized.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
model config if present, otherwise is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since its not
yet compatible with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueue requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs (Optional[bool]): If set to True, token log
probabilities are not returned during speculative decoding.
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
Returns: # Note: After next release, the method parameter will be used to
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if # specify the speculative method, which helps to extend the
the necessary conditions are met, else None. # configuration of non-model-based proposers, and the model parameter
""" # will be used when the draft model or head is needed.
if speculative_model is None: # If users do not specify the method, the speculative method will
if num_speculative_tokens is not None: # be detected automatically if possible. If the speculative method can
if target_model_config.hf_text_config.model_type \ # not be detected, it will be considered as the draft-model-based
# method by default.
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config.hf_text_config.model_type \
== "deepseek_v3": == "deepseek_v3":
# use the draft model from the same model: # use the draft model from the same model:
speculative_model = target_model_config.model self.model = self.target_model_config.model
else: elif self.method in ("ngram", "[ngram]"):
raise ValueError( self.model = "ngram"
"num_speculative_tokens was provided without "
"speculative_model.")
else: else:
return None raise ValueError("num_speculative_tokens was provided without "
"speculative model.")
if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2): # Automatically configure the ngram method during configuration
raise ValueError("Expect the batch size threshold of disabling " # refactoring to ensure a smooth transition.
"speculative decoding is > 1, but got " if self.method is None and (self.model is not None
f"{speculative_disable_by_batch_size=}") and self.model in ("ngram", "[ngram]")):
if (enable_chunked_prefill and speculative_model == "eagle"): self.method = "ngram"
raise ValueError("Chunked prefill and EAGLE are not compatible.")
# TODO: The user should be able to specify revision/max model len if self.method in ("ngram", "[ngram]"):
# for the draft model. It is not currently supported. # Unified to "ngram" internally
draft_revision = None self.method = "ngram"
draft_code_revision = None # Set default values if not provided
draft_quantization = speculative_model_quantization if (self.prompt_lookup_min is None
and self.prompt_lookup_max is None):
if speculative_model == "[ngram]": # TODO(woosuk): Tune these values. They are arbitrarily chosen.
if ngram_prompt_lookup_min is None: self.prompt_lookup_min = 5
ngram_prompt_lookup_min = 1 self.prompt_lookup_max = 5
if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: elif self.prompt_lookup_min is None:
raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") assert self.prompt_lookup_max is not None
if ngram_prompt_lookup_min < 1: self.prompt_lookup_min = self.prompt_lookup_max
raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") elif self.prompt_lookup_max is None:
if ngram_prompt_lookup_min > ngram_prompt_lookup_max: assert self.prompt_lookup_min is not None
raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " self.prompt_lookup_max = self.prompt_lookup_min
f"larger than {ngram_prompt_lookup_max=}")
# Validate values
if self.prompt_lookup_min < 1:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
if self.prompt_lookup_max < 1:
raise ValueError(
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must "
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
# TODO: current we still need extract vocab_size from target model # TODO: current we still need extract vocab_size from target model
# config, in future, we may try refactor it out, and set # config, in future, we may try refactor it out, and set
# draft related config as None here. # draft related config as None here.
draft_model_config = target_model_config self.draft_model_config = self.target_model_config
draft_parallel_config = target_parallel_config self.draft_parallel_config = self.target_parallel_config
else: else:
ngram_prompt_lookup_max = 0 self.prompt_lookup_max = 0
ngram_prompt_lookup_min = 0 self.prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model, if self.model is not None:
task="draft", self.draft_model_config = ModelConfig(
tokenizer=target_model_config.tokenizer, model=self.model,
tokenizer_mode=target_model_config.tokenizer_mode, task="draft",
trust_remote_code=target_model_config.trust_remote_code, tokenizer=self.target_model_config.tokenizer,
allowed_local_media_path=target_model_config. tokenizer_mode=self.target_model_config.tokenizer_mode,
allowed_local_media_path, trust_remote_code=self.target_model_config.
dtype=target_model_config.dtype, trust_remote_code,
seed=target_model_config.seed, allowed_local_media_path=self.target_model_config.
revision=draft_revision, allowed_local_media_path,
code_revision=draft_code_revision, dtype=self.target_model_config.dtype,
tokenizer_revision=target_model_config.tokenizer_revision, seed=self.target_model_config.seed,
max_model_len=None, revision=self.revision,
spec_target_max_model_len=target_model_config.max_model_len, code_revision=self.code_revision,
quantization=draft_quantization, tokenizer_revision=self.target_model_config.
enforce_eager=target_model_config.enforce_eager, tokenizer_revision,
max_seq_len_to_capture=target_model_config. max_model_len=None,
max_seq_len_to_capture, spec_target_max_model_len=self.target_model_config.
max_logprobs=target_model_config.max_logprobs, max_model_len,
hf_overrides=SpeculativeConfig.hf_config_override, quantization=self.quantization,
) enforce_eager=self.target_model_config.enforce_eager,
max_seq_len_to_capture=self.target_model_config.
draft_hf_config = draft_model_config.hf_config max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
# Detect EAGLE prefix to replace hf_config for EAGLE draft_model # Automatically detect the method
if "eagle-" in draft_model_config.model.lower(): if "eagle-" in self.draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig self.method = "eagle"
if isinstance(draft_model_config.hf_config, EAGLEConfig): elif self.draft_model_config.hf_config.model_type == "medusa":
pass self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
else: else:
eagle_config = EAGLEConfig(draft_model_config.hf_config) self.method = "draft_model"
draft_model_config.hf_config = eagle_config
# Replace hf_config for EAGLE draft_model
if (num_speculative_tokens is not None if self.method == "eagle":
and hasattr(draft_hf_config, "num_lookahead_tokens")): if self.enable_chunked_prefill:
draft_hf_config.num_lookahead_tokens = num_speculative_tokens raise ValueError(
n_predict = getattr(draft_hf_config, "n_predict", None) "Chunked prefill and EAGLE are not compatible.")
if n_predict is not None:
if num_speculative_tokens is None: from vllm.transformers_utils.configs.eagle import (
# Default to max value defined in draft model config. EAGLEConfig)
num_speculative_tokens = n_predict if isinstance(self.draft_model_config.hf_config,
elif num_speculative_tokens > n_predict and \ EAGLEConfig):
num_speculative_tokens % n_predict != 0: pass
# Ensure divisibility for MTP module reuse. else:
raise ValueError( eagle_config = EAGLEConfig(
f"{num_speculative_tokens=} must be divisible by " self.draft_model_config.hf_config)
f"{n_predict=}") self.draft_model_config.hf_config = eagle_config
speculative_draft_tensor_parallel_size = \ if (self.num_speculative_tokens is not None
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( and hasattr(self.draft_model_config.hf_config,
target_parallel_config, "num_lookahead_tokens")):
speculative_draft_tensor_parallel_size, self.draft_model_config.hf_config.num_lookahead_tokens = \
draft_hf_config self.num_speculative_tokens
)
n_predict = getattr(self.draft_model_config.hf_config,
"n_predict", None)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif self.num_speculative_tokens > n_predict and \
self.num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config
)
draft_model_config.max_model_len = ( self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len( SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len, self.max_model_len,
draft_model_config.max_model_len, self.draft_model_config.max_model_len,
target_model_config.max_model_len, self.target_model_config.max_model_len,
)) ))
draft_parallel_config = ( self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config( SpeculativeConfig.create_draft_parallel_config(
target_parallel_config, self.target_parallel_config,
speculative_draft_tensor_parallel_size, draft_hf_config)) self.draft_tensor_parallel_size))
if num_speculative_tokens is None: if self.acceptance_method == "typical_acceptance_sampler":
raise ValueError( if self.posterior_threshold is None:
"num_speculative_tokens must be provided with " self.posterior_threshold = 0.09
"speculative_model unless the draft model config contains an " if self.posterior_alpha is None:
"n_predict parameter.") self.posterior_alpha = 0.3
if typical_acceptance_sampler_posterior_threshold is None: self._verify_args()
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True
return SpeculativeConfig(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_mqa_scorer,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)
@staticmethod @staticmethod
def _maybe_override_draft_max_model_len( def _maybe_override_draft_max_model_len(
...@@ -2108,7 +2189,7 @@ class SpeculativeConfig: ...@@ -2108,7 +2189,7 @@ class SpeculativeConfig:
) )
@staticmethod @staticmethod
def _verify_and_get_draft_model_tensor_parallel_size( def _verify_and_get_draft_tp(
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int], speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig) -> int: draft_hf_config: PretrainedConfig) -> int:
...@@ -2140,7 +2221,6 @@ class SpeculativeConfig: ...@@ -2140,7 +2221,6 @@ class SpeculativeConfig:
def create_draft_parallel_config( def create_draft_parallel_config(
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int, speculative_draft_tensor_parallel_size: int,
draft_hf_config: PretrainedConfig,
) -> ParallelConfig: ) -> ParallelConfig:
"""Create a parallel config for use by the draft worker. """Create a parallel config for use by the draft worker.
...@@ -2164,74 +2244,13 @@ class SpeculativeConfig: ...@@ -2164,74 +2244,13 @@ class SpeculativeConfig:
return draft_parallel_config return draft_parallel_config
def __init__(
self,
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
speculative_disable_mqa_scorer: Optional[bool],
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
):
"""Create a SpeculativeConfig object.
Args:
draft_model_config: ModelConfig for the draft model.
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueue requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values 'rejection_sampler' and 'typical_acceptance_sampler'
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self.draft_token_acceptance_method = draft_token_acceptance_method
self.typical_acceptance_sampler_posterior_threshold = \
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats
self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative model unless the draft model config contains an "
"n_predict parameter.")
if self.num_speculative_tokens <= 0: if self.num_speculative_tokens <= 0:
raise ValueError("Expected num_speculative_tokens to be greater " raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).") f"than zero ({self.num_speculative_tokens}).")
...@@ -2241,29 +2260,34 @@ class SpeculativeConfig: ...@@ -2241,29 +2260,34 @@ class SpeculativeConfig:
self.draft_parallel_config) self.draft_parallel_config)
# Validate and set draft token acceptance related settings. # Validate and set draft token acceptance related settings.
if (self.draft_token_acceptance_method is None): if self.acceptance_method is None:
raise ValueError("draft_token_acceptance_method is not set. " raise ValueError("acceptance_method is not set. "
"Expected values are rejection_sampler or " "Expected values are rejection_sampler or "
"typical_acceptance_sampler.") "typical_acceptance_sampler.")
if (self.draft_token_acceptance_method != 'rejection_sampler' if (self.acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method and self.acceptance_method != 'typical_acceptance_sampler'):
!= 'typical_acceptance_sampler'):
raise ValueError( raise ValueError(
"Expected draft_token_acceptance_method to be either " "Expected acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it " "rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.draft_token_acceptance_method}") f"is {self.acceptance_method}")
if (self.typical_acceptance_sampler_posterior_threshold < 0 if self.acceptance_method == "typical_acceptance_sampler" and (
or self.typical_acceptance_sampler_posterior_alpha < 0): (self.posterior_threshold is not None
and self.posterior_threshold < 0) or
(self.posterior_alpha is not None and self.posterior_alpha < 0)):
raise ValueError( raise ValueError(
"Expected typical_acceptance_sampler_posterior_threshold " "Expected the posterior_threshold and posterior_alpha of "
"and typical_acceptance_sampler_posterior_alpha to be > 0. " "typical_acceptance_sampler to be > 0. "
"Instead found " "Instead found posterior_threshold = "
f"typical_acceptance_sampler_posterior_threshold = " f"{self.posterior_threshold} and posterior_alpha = "
f"{self.typical_acceptance_sampler_posterior_threshold} and " f"{self.posterior_alpha}")
f"typical_acceptance_sampler_posterior_alpha = "
f"{self.typical_acceptance_sampler_posterior_alpha}") if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")
@property @property
def num_lookahead_slots(self) -> int: def num_lookahead_slots(self) -> int:
...@@ -2276,8 +2300,8 @@ class SpeculativeConfig: ...@@ -2276,8 +2300,8 @@ class SpeculativeConfig:
return self.num_speculative_tokens return self.num_speculative_tokens
def __repr__(self) -> str: def __repr__(self) -> str:
if self.ngram_prompt_lookup_max > 0: if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0:
draft_model = "[ngram]" draft_model = "ngram"
else: else:
draft_model = self.draft_model_config.model draft_model = self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens num_spec_tokens = self.num_speculative_tokens
...@@ -2785,12 +2809,14 @@ class DecodingConfig: ...@@ -2785,12 +2809,14 @@ class DecodingConfig:
return hash_str return hash_str
def __post_init__(self): def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
]
backend = GuidedDecodingParams( backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name backend=self.guided_decoding_backend).backend_name
if backend not in valid_guided_backends: if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}," raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
f" must be one of {valid_guided_backends}") f" must be one of {valid_guided_backends}")
...@@ -3283,7 +3309,8 @@ class VllmConfig: ...@@ -3283,7 +3309,8 @@ class VllmConfig:
init=True) # type: ignore init=True) # type: ignore
load_config: LoadConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore
lora_config: Optional[LoRAConfig] = None lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None speculative_config: SpeculativeConfig = field(default=None,
init=True) # type: ignore
decoding_config: Optional[DecodingConfig] = None decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None
......
...@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): ...@@ -341,8 +341,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
assert device in self._allocators assert device in self._allocators
return self._allocators[device].get_prefix_cache_hit_rate() return self._allocators[device].get_prefix_cache_hit_rate()
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for specified or all devices."""
if device:
return self._allocators[device].reset_prefix_cache()
success = True success = True
for allocator in self._allocators.values(): for allocator in self._allocators.values():
success = success and allocator.reset_prefix_cache() success = success and allocator.reset_prefix_cache()
......
...@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC): ...@@ -305,7 +305,7 @@ class DeviceAwareBlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache.""" """Reset prefix cache."""
pass pass
......
...@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -456,8 +456,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_allocator.get_prefix_cache_hit_rate(device) return self.block_allocator.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.block_allocator.reset_prefix_cache() return self.block_allocator.reset_prefix_cache(device)
def _can_swap(self, def _can_swap(self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Tuple from typing import Tuple
...@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC): ...@@ -125,8 +125,8 @@ class BlockSpaceManager(ABC):
pass pass
@abstractmethod @abstractmethod
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for specified or all devices."""
pass pass
@abstractmethod @abstractmethod
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple from typing import List, Optional, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup from vllm.sequence import Sequence, SequenceGroup
...@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager): ...@@ -92,7 +92,7 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1 return -1
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return True return True
def get_num_cached_tokens(self, seq: Sequence) -> int: def get_num_cached_tokens(self, seq: Sequence) -> int:
......
...@@ -634,8 +634,8 @@ class Scheduler: ...@@ -634,8 +634,8 @@ class Scheduler:
def get_prefix_cache_hit_rate(self, device: Device) -> float: def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device) return self.block_manager.get_prefix_cache_hit_rate(device)
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
return self.block_manager.reset_prefix_cache() return self.block_manager.reset_prefix_cache(device)
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
......
...@@ -6,16 +6,25 @@ from typing import Optional ...@@ -6,16 +6,25 @@ from typing import Optional
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .base_device_communicator import DeviceCommunicatorBase from .base_device_communicator import DeviceCommunicatorBase
USE_RAY = parallel_config = get_current_vllm_config(
).parallel_config.distributed_executor_backend == "ray"
logger = init_logger(__name__)
if current_platform.is_tpu(): if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from torch_xla._internal import pjrt from torch_xla._internal import pjrt
from vllm.executor import ray_utils if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase): class TpuCommunicator(DeviceCommunicatorBase):
...@@ -33,19 +42,32 @@ class TpuCommunicator(DeviceCommunicatorBase): ...@@ -33,19 +42,32 @@ class TpuCommunicator(DeviceCommunicatorBase):
global_rank = self.global_rank global_rank = self.global_rank
global_world_size = self.global_world_size global_world_size = self.global_world_size
# Calculate how many TPU nodes are in the current deployment. This if USE_RAY:
# is the Ray placement group if it is deployed with Ray. Default logger.info("TpuCommunicator initialized with RAY")
# to the number of TPU nodes in the Ray cluster. The number of TPU # Calculate how many TPU nodes are in the current deployment. This
# nodes is computed by the total number of TPUs divided by the # is the Ray placement group if it is deployed with Ray. Default
# number of TPU accelerators per node, to account for clusters # to the number of TPU nodes in the Ray cluster. The number of TPU
# with both CPUs and TPUs. # nodes is computed by the total number of TPUs divided by the
num_nodes = ray_utils.get_num_tpu_nodes() # number of TPU accelerators per node, to account for clusters
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() # with both CPUs and TPUs.
if num_nodes_in_pg > 0: num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes = num_nodes_in_pg num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
local_world_size = global_world_size // num_nodes num_nodes = num_nodes_in_pg
local_rank = global_rank % local_world_size
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
else:
logger.info("TpuCommunicator initialized with MP")
# Sanity: Verify we run on a single host
num_hosts = torch_xla.tpu.num_tpu_workers()
assert num_hosts == 1
# Get the current number of TPUs (we have locally)
local_world_size = torch_xla.tpu.num_available_chips()
# Get current rank
local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments. # Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU # On GKE, this is needed for libtpu and TPU driver to know which TPU
......
...@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union ...@@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch import torch
import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
...@@ -37,6 +38,8 @@ class SimpleConnector(KVConnectorBase): ...@@ -37,6 +38,8 @@ class SimpleConnector(KVConnectorBase):
self.config = config.kv_transfer_config self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size self.tp_size = config.parallel_config.tensor_parallel_size
self.is_deepseek_mla = config.model_config.is_deepseek_mla
self.use_mla_opt = not envs.VLLM_MLA_DISABLE
if self.config.kv_connector == "PyNcclConnector": if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
...@@ -167,8 +170,26 @@ class SimpleConnector(KVConnectorBase): ...@@ -167,8 +170,26 @@ class SimpleConnector(KVConnectorBase):
num_heads = int(model_config.num_key_value_heads / self.tp_size) num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads num_attention_heads = model_config.num_attention_heads
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads)) # Deepseek's MLA (Multi-head Latent Attention) uses two different
# kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
# When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
# resulting in a kv_cache shape of [num_blks, blk_size, 1,
# kv_lora_rank + qk_rope_head_dim].
# When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
# to a kv_cache shape of [2, num_blks, blk_size,
# num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
# For more details, see vllm/attention/backends/mla/common.py.
if self.is_deepseek_mla and self.use_mla_opt:
head_size = model_config.kv_lora_rank + \
model_config.qk_rope_head_dim
num_heads = 1
elif self.is_deepseek_mla and not self.use_mla_opt:
head_size = model_config.qk_nope_head_dim + \
model_config.qk_rope_head_dim
else:
head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
# query_lens contains new KV caches that are added to vLLM. # query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance # so we will send them to decode instance
...@@ -192,8 +213,12 @@ class SimpleConnector(KVConnectorBase): ...@@ -192,8 +213,12 @@ class SimpleConnector(KVConnectorBase):
for layer_id in range(start_layer, end_layer): for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer] kv_cache = kv_caches[layer_id - start_layer]
key_cache = kv_cache[0].reshape(-1, num_heads, head_size) if self.is_deepseek_mla and self.use_mla_opt:
value_cache = kv_cache[1].reshape(-1, num_heads, head_size) key_cache = kv_cache.reshape(-1, num_heads, head_size)
value_cache = kv_cache.reshape(-1, num_heads, head_size)
else:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
current_slot_mapping = slot_mapping_flat[start_pos:end_pos] current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
...@@ -223,6 +248,8 @@ class SimpleConnector(KVConnectorBase): ...@@ -223,6 +248,8 @@ class SimpleConnector(KVConnectorBase):
# and hidden states. # and hidden states.
bypass_model_exec = True bypass_model_exec = True
model_config = model_executable.model.config
input_tokens_tensor = model_input.input_tokens input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens seq_lens = model_input.attn_metadata.seq_lens
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
...@@ -291,19 +318,35 @@ class SimpleConnector(KVConnectorBase): ...@@ -291,19 +318,35 @@ class SimpleConnector(KVConnectorBase):
kv_cache = kv_caches[i - model_executable.model.start_layer] kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i] layer = model_executable.model.layers[i]
key_cache, value_cache = kv_cache[0], kv_cache[1] if self.is_deepseek_mla and self.use_mla_opt:
ops.reshape_and_cache_flash( layer.self_attn.attn = layer.self_attn.mla_attn
keys[i - model_executable.model.start_layer].to( k_c_normed_k_pe = keys[
key_cache.device), i - model_executable.model.start_layer].to(
values[i - model_executable.model.start_layer].to( kv_cache.device).squeeze(1)
value_cache.device), k_c_normed = k_c_normed_k_pe[:, :model_config.kv_lora_rank]
key_cache, k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank:]
value_cache, ops.concat_and_cache_mla(
slot_mapping[start_pos:end_pos], k_c_normed,
layer.self_attn.attn.kv_cache_dtype, k_pe,
layer.self_attn.attn._k_scale, kv_cache,
layer.self_attn.attn._v_scale, slot_mapping[start_pos:end_pos],
) layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
)
else:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
hidden_or_intermediate_states_for_one_req.append(hidden) hidden_or_intermediate_states_for_one_req.append(hidden)
......
...@@ -26,7 +26,7 @@ from vllm.plugins import load_general_plugins ...@@ -26,7 +26,7 @@ from vllm.plugins import load_general_plugins
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
...@@ -40,7 +40,6 @@ DEVICE_OPTIONS = [ ...@@ -40,7 +40,6 @@ DEVICE_OPTIONS = [
"cuda", "cuda",
"neuron", "neuron",
"cpu", "cpu",
"openvino",
"tpu", "tpu",
"xpu", "xpu",
"hpu", "hpu",
...@@ -120,6 +119,7 @@ class EngineArgs: ...@@ -120,6 +119,7 @@ class EngineArgs:
block_size: Optional[int] = None block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None enable_prefix_caching: Optional[bool] = None
disable_sliding_window: bool = False disable_sliding_window: bool = False
disable_cascade_attn: bool = False
use_v2_block_manager: bool = True use_v2_block_manager: bool = True
swap_space: float = 4 # GiB swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB cpu_offload_gb: float = 0 # GiB
...@@ -177,7 +177,10 @@ class EngineArgs: ...@@ -177,7 +177,10 @@ class EngineArgs:
guided_decoding_backend: str = 'xgrammar' guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
# Speculative decoding configuration.
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
# TODO(Shangming): Deprecate these out-of-date params after next release
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None speculative_draft_tensor_parallel_size: Optional[int] = None
...@@ -190,9 +193,9 @@ class EngineArgs: ...@@ -190,9 +193,9 @@ class EngineArgs:
spec_decoding_acceptance_method: str = 'rejection_sampler' spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None disable_logprobs_during_spec_decoding: Optional[bool] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = None show_hidden_metrics_for_version: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
...@@ -338,9 +341,15 @@ class EngineArgs: ...@@ -338,9 +341,15 @@ class EngineArgs:
'CoreWeave. See the Tensorize vLLM Model script in the Examples ' 'CoreWeave. See the Tensorize vLLM Model script in the Examples '
'section for more information.\n' 'section for more information.\n'
'* "runai_streamer" will load the Safetensors weights using Run:ai' '* "runai_streamer" will load the Safetensors weights using Run:ai'
'Model Streamer \n' 'Model Streamer.\n'
'* "bitsandbytes" will load the weights using bitsandbytes ' '* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n') 'quantization.\n'
'* "sharded_state" will load weights from pre-sharded checkpoint '
'files, supporting efficient loading of tensor-parallel models\n'
'* "gguf" will load weights from GGUF format files (details '
'specified in https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n'
'* "mistral" will load weights from consolidated safetensors files '
'used by Mistral models.\n')
parser.add_argument( parser.add_argument(
'--config-format', '--config-format',
default=EngineArgs.config_format, default=EngineArgs.config_format,
...@@ -774,7 +783,11 @@ class EngineArgs: ...@@ -774,7 +783,11 @@ class EngineArgs:
const="True", const="True",
help='If set, the prefill requests can be chunked based on the ' help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.') 'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=nullable_str,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument( parser.add_argument(
'--speculative-model', '--speculative-model',
type=nullable_str, type=nullable_str,
...@@ -1096,6 +1109,16 @@ class EngineArgs: ...@@ -1096,6 +1109,16 @@ class EngineArgs:
"using. This is used to parse the reasoning content into OpenAI " "using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``.") "API format. Required for ``--enable-reasoning``.")
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
default=False,
help="Disable cascade attention for V1. While cascade attention "
"does not change the mathematical correctness, disabling it "
"could be useful for preventing potential numerical issues. "
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")
return parser return parser
@classmethod @classmethod
...@@ -1141,6 +1164,7 @@ class EngineArgs: ...@@ -1141,6 +1164,7 @@ class EngineArgs:
max_seq_len_to_capture=self.max_seq_len_to_capture, max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs, max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window, disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init, skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
...@@ -1158,22 +1182,15 @@ class EngineArgs: ...@@ -1158,22 +1182,15 @@ class EngineArgs:
) )
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
if (self.quantization == "bitsandbytes" or
self.qlora_adapter_name_or_path is not None) and \
self.load_format != "bitsandbytes":
raise ValueError(
"BitsAndBytes quantization and QLoRA adapter only support "
f"'bitsandbytes' load format, but got {self.load_format}")
if (self.load_format == "bitsandbytes" or if(self.qlora_adapter_name_or_path is not None) and \
self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes": self.quantization != "bitsandbytes":
raise ValueError( raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support " "QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}") f"'bitsandbytes' quantization, but got {self.quantization}")
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
return LoadConfig( return LoadConfig(
load_format=self.load_format, load_format=self.load_format,
download_dir=self.download_dir, download_dir=self.download_dir,
...@@ -1182,6 +1199,82 @@ class EngineArgs: ...@@ -1182,6 +1199,82 @@ class EngineArgs:
use_tqdm_on_load=self.use_tqdm_on_load, use_tqdm_on_load=self.use_tqdm_on_load,
) )
def create_speculative_config(
    self,
    target_model_config: ModelConfig,
    target_parallel_config: ParallelConfig,
    enable_chunked_prefill: bool,
    disable_log_stats: bool,
) -> Optional["SpeculativeConfig"]:
    """Initializes and returns a SpeculativeConfig object based on
    `speculative_config`.

    `speculative_config` may be a dictionary passed directly from the
    engine, or a string coming from the '--speculative-config' CLI
    argument. A string is parsed as JSON first (the documented format);
    if that fails it is parsed as a Python literal so that strings
    written as Python dicts (e.g. with single quotes) keep working.

    If `speculative_config` is not set, the configuration dictionary is
    assembled from the standalone speculative-related parameters, which
    are scheduled for deprecation in the next release. When neither
    `speculative_model` nor `num_speculative_tokens` is set either,
    speculative decoding is considered disabled.

    Note: `self.speculative_config` is replaced by the resulting
    dictionary and updated in place with the target configs.

    Args:
        target_model_config: Model config of the target model.
        target_parallel_config: Parallel config of the target model.
        enable_chunked_prefill: Forwarded to the SpeculativeConfig.
        disable_log_stats: Forwarded to the SpeculativeConfig.

    Returns:
        The SpeculativeConfig built by `SpeculativeConfig.from_dict`, or
        None when no speculative decoding option was provided.

    Raises:
        ValueError: If a string `speculative_config` can be parsed
            neither as JSON nor as a Python literal, or if the
            resulting value is not a dictionary.
    """
    if self.speculative_config is None:
        if (self.speculative_model is None
                and self.num_speculative_tokens is None):
            return None

        # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
        # only allow '--speculative-config' after next release
        logger.warning_once(
            "Please use '--speculative-config' to set all configurations "
            "related to speculative decoding. The current method of "
            "specifying the model through '--speculative-model' and "
            "adding related parameters (e.g., '--num-speculative-tokens') "
            "separately will be deprecated in the next release.")

        spec_config_dict = {
            "model": self.speculative_model,
            "quantization": self.speculative_model_quantization,
            "max_model_len": self.speculative_max_model_len,
            "draft_tensor_parallel_size":
            self.speculative_draft_tensor_parallel_size,
            "num_speculative_tokens": self.num_speculative_tokens,
            "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
            "disable_by_batch_size":
            self.speculative_disable_by_batch_size,
            "prompt_lookup_max": self.ngram_prompt_lookup_max,
            "prompt_lookup_min": self.ngram_prompt_lookup_min,
            "acceptance_method": self.spec_decoding_acceptance_method,
            "posterior_threshold":
            self.typical_acceptance_sampler_posterior_threshold,
            "posterior_alpha":
            self.typical_acceptance_sampler_posterior_alpha,
            "disable_logprobs": self.disable_logprobs_during_spec_decoding,
        }

        self.speculative_config = spec_config_dict
    elif isinstance(self.speculative_config, str):
        import ast
        import json

        raw_config = self.speculative_config
        try:
            # The CLI documents this option as a JSON string, so JSON
            # literals such as true/false/null must be accepted.
            self.speculative_config = json.loads(raw_config)
        except json.JSONDecodeError:
            # Backward compatibility: earlier behavior evaluated the
            # string as a Python literal (e.g. single-quoted dicts).
            try:
                self.speculative_config = ast.literal_eval(raw_config)
            except (ValueError, SyntaxError) as err:
                raise ValueError(
                    "'--speculative-config' must be a JSON object "
                    "string, but it could not be parsed.") from err

    # Explicit check instead of `assert`: this validates user input and
    # must not disappear when Python runs with -O.
    if not isinstance(self.speculative_config, dict):
        raise ValueError(
            "speculative_config must be a dict (or a JSON object "
            "string), but got "
            f"{type(self.speculative_config).__name__}")

    # Note(Shangming): These parameters are not obtained from the cli arg
    # '--speculative-config' and must be passed in when creating the engine
    # config.
    self.speculative_config.update({
        "target_model_config": target_model_config,
        "target_parallel_config": target_parallel_config,
        "enable_chunked_prefill": enable_chunked_prefill,
        "disable_log_stats": disable_log_stats,
    })
    return SpeculativeConfig.from_dict(self.speculative_config)
def create_engine_config( def create_engine_config(
self, self,
usage_context: Optional[UsageContext] = None, usage_context: Optional[UsageContext] = None,
...@@ -1228,6 +1321,8 @@ class EngineArgs: ...@@ -1228,6 +1321,8 @@ class EngineArgs:
else: else:
self._set_default_args_v0(model_config) self._set_default_args_v0(model_config)
assert self.enable_chunked_prefill is not None
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
...@@ -1240,6 +1335,18 @@ class EngineArgs: ...@@ -1240,6 +1335,18 @@ class EngineArgs:
cpu_offload_gb=self.cpu_offload_gb, cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales, calculate_kv_scales=self.calculate_kv_scales,
) )
# Get the current placement group if Ray is initialized and
# we are in a Ray actor. If so, then the placement group will be
# passed to spawned processes.
placement_group = None
if is_in_ray_actor():
import ray
# This call initializes Ray automatically if it is not initialized,
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
...@@ -1252,36 +1359,17 @@ class EngineArgs: ...@@ -1252,36 +1359,17 @@ class EngineArgs:
self.tokenizer_pool_extra_config, self.tokenizer_pool_extra_config,
), ),
ray_workers_use_nsight=self.ray_workers_use_nsight, ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend, distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls, worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls, worker_extension_cls=self.worker_extension_cls,
) )
speculative_config = SpeculativeConfig.maybe_create_spec_config( speculative_config = self.create_speculative_config(
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_model_quantization = \
self.speculative_model_quantization,
speculative_draft_tensor_parallel_size = \
self.speculative_draft_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
speculative_max_model_len=self.speculative_max_model_len,
enable_chunked_prefill=self.enable_chunked_prefill, enable_chunked_prefill=self.enable_chunked_prefill,
disable_log_stats=self.disable_log_stats, disable_log_stats=self.disable_log_stats,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
) )
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
...@@ -1436,16 +1524,6 @@ class EngineArgs: ...@@ -1436,16 +1524,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
if self.worker_cls != EngineArgs.worker_cls:
_raise_or_fallback(feature_name="--worker-cls",
recommend_to_remove=False)
return False
if self.worker_extension_cls != EngineArgs.worker_extension_cls:
_raise_or_fallback(feature_name="--worker-extension-cls",
recommend_to_remove=False)
return False
if self.num_scheduler_steps != EngineArgs.num_scheduler_steps: if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
_raise_or_fallback(feature_name="--num-scheduler-steps", _raise_or_fallback(feature_name="--num-scheduler-steps",
recommend_to_remove=True) recommend_to_remove=True)
...@@ -1462,7 +1540,9 @@ class EngineArgs: ...@@ -1462,7 +1540,9 @@ class EngineArgs:
return False return False
# Only support Xgrammar for guided decoding so far. # Only support Xgrammar for guided decoding so far.
SUPPORTED_GUIDED_DECODING = ["xgrammar", "xgrammar:nofallback"] SUPPORTED_GUIDED_DECODING = [
"xgrammar", "xgrammar:disable-any-whitespace"
]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend", _raise_or_fallback(feature_name="--guided-decoding-backend",
recommend_to_remove=False) recommend_to_remove=False)
...@@ -1554,7 +1634,7 @@ class EngineArgs: ...@@ -1554,7 +1634,7 @@ class EngineArgs:
if (self.speculative_model is not None if (self.speculative_model is not None
or self.num_speculative_tokens is not None): or self.num_speculative_tokens is not None):
# This is supported but experimental (handled below). # This is supported but experimental (handled below).
if self.speculative_model == "[ngram]": if self.speculative_model in ("ngram", "[ngram]"):
pass pass
else: else:
_raise_or_fallback(feature_name="Speculative Decoding", _raise_or_fallback(feature_name="Speculative Decoding",
...@@ -1570,7 +1650,7 @@ class EngineArgs: ...@@ -1570,7 +1650,7 @@ class EngineArgs:
# No FlashInfer or XFormers so far. # No FlashInfer or XFormers so far.
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
"TRITON_MLA", "FLASHMLA" "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
...@@ -1602,7 +1682,8 @@ class EngineArgs: ...@@ -1602,7 +1682,8 @@ class EngineArgs:
return False return False
# ngram is supported on V1, but off by default for now. # ngram is supported on V1, but off by default for now.
if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"): if self.speculative_model in (
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
return False return False
# Non-CUDA is supported on V1, but off by default for now. # Non-CUDA is supported on V1, but off by default for now.
...@@ -1683,7 +1764,7 @@ class EngineArgs: ...@@ -1683,7 +1764,7 @@ class EngineArgs:
# V1 should use the new scheduler by default. # V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default # Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls: if self.scheduler_cls == EngineArgs.scheduler_cls:
self.scheduler_cls = "vllm.v1.core.scheduler.Scheduler" self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"
# When no user override, set the default values based on the usage # When no user override, set the default values based on the usage
# context. # context.
......
...@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams ...@@ -35,7 +35,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import deprecate_kwargs, weak_bind from vllm.utils import Device, deprecate_kwargs, weak_bind
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
...@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine):
preprocessed_inputs = await self.input_preprocessor.preprocess_async( preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
...@@ -1216,8 +1215,9 @@ class AsyncLLMEngine(EngineClient): ...@@ -1216,8 +1215,9 @@ class AsyncLLMEngine(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
self.engine.stop_profile() self.engine.stop_profile()
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self,
self.engine.reset_prefix_cache() device: Optional[Device] = None) -> None:
self.engine.reset_prefix_cache(device)
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment