Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
...@@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
...@@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
...@@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache, value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
...@@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]], kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
...@@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend): ...@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]: def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata return ROCmFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
return ROCmFlashAttentionMetadataBuilder
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return self._cached_decode_metadata return self._cached_decode_metadata
class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
_metadata_cls = ROCmFlashAttentionMetadata
def _make_alibi_bias(alibi_slopes: torch.Tensor, def _make_alibi_bias(alibi_slopes: torch.Tensor,
dtype: torch.dtype, dtype: torch.dtype,
seq_lens: Optional[List[int]], seq_lens: Optional[List[int]],
...@@ -266,6 +277,12 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -266,6 +277,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# flash_attn_varlen_func) # flash_attn_varlen_func)
self.attn_func = triton_attention # flash_attn_varlen_func self.attn_func = triton_attention # flash_attn_varlen_func
logger.debug("Using Triton FA in ROCmBackend") logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
...@@ -298,7 +315,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -298,7 +315,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -338,7 +356,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -338,7 +356,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
...@@ -420,6 +439,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -420,6 +439,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seqlen_k=prefill_meta.max_prefill_seq_len, max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
) )
# common code for prefill # common code for prefill
...@@ -455,7 +476,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -455,7 +476,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, # type: ignore attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
...@@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
...@@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
PagedAttention.write_to_paged_cache(key, value, key_cache, PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
self.kv_cache_dtype, kv_scale) self.kv_cache_dtype, k_scale,
v_scale)
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens is not None
...@@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
"""Attention backend utils""" """Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.utils import make_tensor_with_pad
# Error string(s) for encoder/decoder # Error string(s) for encoder/decoder
# unsupported attention scenarios # unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.") "with encoder/decoder models.")
PAD_SLOT_ID = -1
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int,
use_v2_block_manager: bool):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
assert use_v2_block_manager or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - sliding_window)
return start_idx
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
block_tables: Dict[int, List[int]]):
"""
Compute slot mapping.
"""
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(self.runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
...@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias, ...@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend): ...@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]: def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata return XFormersMetadata
@staticmethod
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
return XFormersMetadataBuilder
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args( ...@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args(
raise AttributeError(f"Invalid attention type {str(attn_type)}") raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
_metadata_cls = XFormersMetadata
class XFormersImpl(AttentionImpl[XFormersMetadata]): class XFormersImpl(AttentionImpl[XFormersMetadata]):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
...@@ -427,7 +437,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -427,7 +437,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
value: Optional[torch.Tensor], value: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: "XFormersMetadata", attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
...@@ -531,7 +542,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -531,7 +542,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
value_cache, value_cache,
updated_slot_mapping, updated_slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale) k_scale, v_scale)
if attn_type != AttentionType.ENCODER: if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill. # Decoder self-attention supports chunked prefill.
...@@ -620,7 +631,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -620,7 +631,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
...@@ -9,7 +9,7 @@ from vllm.attention.selector import get_attn_backend ...@@ -9,7 +9,7 @@ from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
class Attention(nn.Module): class Attention(nn.Module):
...@@ -34,6 +34,7 @@ class Attention(nn.Module): ...@@ -34,6 +34,7 @@ class Attention(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None, blocksparse_params: Optional[Dict[str, Any]] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
if cache_config is not None: if cache_config is not None:
...@@ -47,29 +48,29 @@ class Attention(nn.Module): ...@@ -47,29 +48,29 @@ class Attention(nn.Module):
if num_kv_heads is None: if num_kv_heads is None:
num_kv_heads = num_heads num_kv_heads = num_heads
# The default kv_scale is set to 1.0. This is ignored # The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with # when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along # expect the pre-quantized k/v_scale to be loaded along
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0 self._k_scale = 1.0
self._v_scale = 1.0
quant_method = quant_config.get_quant_method( quant_method = quant_config.get_quant_method(
self) if quant_config else None self, prefix=prefix) if quant_config else None
if quant_method is not None: if quant_method is not None:
assert isinstance(quant_method, Fp8KVCacheMethod) assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8 # TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior # checkpoint config and become the "auto" behavior
if "fp8" in self.kv_cache_dtype: if self.kv_cache_dtype == "fp8_e5m2":
if self.kv_cache_dtype == "fp8_e5m2": raise ValueError("fp8_e5m2 kv-cache is not supported with "
raise ValueError("fp8_e5m2 kv-cache is not supported with " "fp8 checkpoints.")
"fp8 checkpoints.") # If quantization is enabled, we make "k_scale" and "v_scale"
# When FP8 quantization is enabled, we make a parameter # parameters so that it can be loaded from the model checkpoint.
# "kv_scale" so that it can be loaded from FP8 checkpoint. # The k/v_scale will then be converted back to native float32
# The kv_scale will then be converted back to self._kv_scale # values after weight loading.
# in a native float32 value after weight loading. self.quant_method = quant_method
self.quant_method = quant_method self.quant_method.create_weights(self)
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model # During model initialization, the default dtype is set as the model
# weight and activation dtype. # weight and activation dtype.
...@@ -98,7 +99,8 @@ class Attention(nn.Module): ...@@ -98,7 +99,8 @@ class Attention(nn.Module):
value, value,
kv_cache, kv_cache,
attn_metadata, attn_metadata,
self._kv_scale, self._k_scale,
self._v_scale,
attn_type=attn_type) attn_type=attn_type)
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
...@@ -45,7 +45,8 @@ class PagedAttention: ...@@ -45,7 +45,8 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
*args, *args,
) -> None: ) -> None:
ipex_modules.PagedAttention.reshape_and_cache( ipex_modules.PagedAttention.reshape_and_cache(
...@@ -64,7 +65,8 @@ class PagedAttention: ...@@ -64,7 +65,8 @@ class PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_scale: float, k_scale: float,
v_scale: float,
*args, *args,
) -> torch.Tensor: ) -> torch.Tensor:
output = torch.empty_like(query) output = torch.empty_like(query)
......
...@@ -66,7 +66,8 @@ class PagedAttention: ...@@ -66,7 +66,8 @@ class PagedAttention:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
) -> None: ) -> None:
ops.reshape_and_cache( ops.reshape_and_cache(
key, key,
...@@ -75,7 +76,8 @@ class PagedAttention: ...@@ -75,7 +76,8 @@ class PagedAttention:
value_cache, value_cache,
slot_mapping.flatten(), slot_mapping.flatten(),
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
@staticmethod @staticmethod
...@@ -90,7 +92,8 @@ class PagedAttention: ...@@ -90,7 +92,8 @@ class PagedAttention:
num_kv_heads: int, num_kv_heads: int,
scale: float, scale: float,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_scale: float, k_scale: float,
v_scale: float,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -135,7 +138,8 @@ class PagedAttention: ...@@ -135,7 +138,8 @@ class PagedAttention:
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
tp_rank, tp_rank,
blocksparse_local_blocks, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_vert_stride,
...@@ -172,7 +176,8 @@ class PagedAttention: ...@@ -172,7 +176,8 @@ class PagedAttention:
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
tp_rank, tp_rank,
blocksparse_local_blocks, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_vert_stride,
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -136,7 +137,7 @@ def which_attn_to_use( ...@@ -136,7 +137,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
if torch.cuda.get_device_capability()[0] != 9: if current_platform.get_device_capability()[0] != 9:
# not Instinct series GPUs. # not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
...@@ -145,7 +146,7 @@ def which_attn_to_use( ...@@ -145,7 +146,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN: if selected_backend == _Backend.FLASH_ATTN:
if torch.cuda.get_device_capability()[0] < 8: if current_platform.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs. # Volta and Turing NVIDIA GPUs.
logger.info( logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing " "Cannot use FlashAttention-2 backend for Volta and Turing "
......
import enum import enum
import json import json
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, ...@@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
from vllm.executor.executor_base import ExecutorBase
from vllm.model_executor.model_loader.loader import BaseModelLoader from vllm.model_executor.model_loader.loader import BaseModelLoader
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -34,6 +37,7 @@ _PP_SUPPORTED_MODELS = [ ...@@ -34,6 +37,7 @@ _PP_SUPPORTED_MODELS = [
"MistralForCausalLM", "MistralForCausalLM",
"Phi3ForCausalLM", "Phi3ForCausalLM",
"GPT2LMHeadModel", "GPT2LMHeadModel",
"MixtralForCausalLM",
] ]
...@@ -237,7 +241,8 @@ class ModelConfig: ...@@ -237,7 +241,8 @@ class ModelConfig:
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
f"supported in ROCm.") f"supported in ROCm.")
if (self.quantization if (self.quantization
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")): not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors")):
logger.warning( logger.warning(
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
...@@ -431,6 +436,7 @@ class CacheConfig: ...@@ -431,6 +436,7 @@ class CacheConfig:
num_gpu_blocks_override: Optional[int] = None, num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False, enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
...@@ -439,6 +445,7 @@ class CacheConfig: ...@@ -439,6 +445,7 @@ class CacheConfig:
self.cache_dtype = cache_dtype self.cache_dtype = cache_dtype
self.sliding_window = sliding_window self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb
self._verify_args() self._verify_args()
self._verify_cache_dtype() self._verify_cache_dtype()
self._verify_prefix_caching() self._verify_prefix_caching()
...@@ -514,11 +521,12 @@ class TokenizerPoolConfig: ...@@ -514,11 +521,12 @@ class TokenizerPoolConfig:
pool type. pool type.
""" """
pool_size: int pool_size: int
pool_type: str pool_type: Union[str, Type["BaseTokenizerGroup"]]
extra_config: dict extra_config: dict
def __post_init__(self): def __post_init__(self):
if self.pool_type not in ("ray", ): if self.pool_type not in ("ray", ) and not isinstance(
self.pool_type, type):
raise ValueError(f"Unknown pool type: {self.pool_type}") raise ValueError(f"Unknown pool type: {self.pool_type}")
if not isinstance(self.extra_config, dict): if not isinstance(self.extra_config, dict):
raise ValueError("extra_config must be a dictionary.") raise ValueError("extra_config must be a dictionary.")
...@@ -582,12 +590,16 @@ class LoadConfig: ...@@ -582,12 +590,16 @@ class LoadConfig:
mainly for profiling. mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for "tensorizer" will use CoreWeave's tensorizer library for
fast weight loading. fast weight loading.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
""" """
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field( model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict) default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None
def __post_init__(self): def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {} model_loader_extra_config = self.model_loader_extra_config or {}
...@@ -596,6 +608,13 @@ class LoadConfig: ...@@ -596,6 +608,13 @@ class LoadConfig:
model_loader_extra_config) model_loader_extra_config)
self._verify_load_format() self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
self.ignore_patterns)
else:
self.ignore_patterns = ["original/**/*"]
def _verify_load_format(self) -> None: def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str): if not isinstance(self.load_format, str):
return return
...@@ -648,7 +667,8 @@ class ParallelConfig: ...@@ -648,7 +667,8 @@ class ParallelConfig:
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
ray_workers_use_nsight: bool = False, ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None, placement_group: Optional["PlacementGroup"] = None,
distributed_executor_backend: Optional[str] = None, distributed_executor_backend: Optional[Union[
str, Type["ExecutorBase"]]] = None,
) -> None: ) -> None:
self.pipeline_parallel_size = pipeline_parallel_size self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size self.tensor_parallel_size = tensor_parallel_size
...@@ -663,7 +683,7 @@ class ParallelConfig: ...@@ -663,7 +683,7 @@ class ParallelConfig:
if worker_use_ray: if worker_use_ray:
if self.distributed_executor_backend is None: if self.distributed_executor_backend is None:
self.distributed_executor_backend = "ray" self.distributed_executor_backend = "ray"
elif self.distributed_executor_backend != "ray": elif not self.use_ray:
raise ValueError(f"worker-use-ray can't be used with " raise ValueError(f"worker-use-ray can't be used with "
f"distributed executor backend " f"distributed executor backend "
f"'{self.distributed_executor_backend}'.") f"'{self.distributed_executor_backend}'.")
...@@ -698,16 +718,25 @@ class ParallelConfig: ...@@ -698,16 +718,25 @@ class ParallelConfig:
self._verify_args() self._verify_args()
self.rank = 0 self.rank = 0
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)
def _verify_args(self) -> None: def _verify_args(self) -> None:
if (self.pipeline_parallel_size > 1 # Lazy import to avoid circular import
and self.distributed_executor_backend == "mp"): from vllm.executor.executor_base import ExecutorBase
raise NotImplementedError("Pipeline parallelism is not supported "
"yet with multiprocessing.") if self.distributed_executor_backend not in (
if self.distributed_executor_backend not in ("ray", "mp", None): "ray", "mp", None) and not (isinstance(
self.distributed_executor_backend, type) and issubclass(
self.distributed_executor_backend, ExecutorBase)):
raise ValueError( raise ValueError(
"Unrecognized distributed executor backend. Supported values " "Unrecognized distributed executor backend "
"are 'ray' or 'mp'.") f"{self.distributed_executor_backend}. Supported "
if self.distributed_executor_backend == "ray": "values are 'ray', 'mp' or custom ExecutorBase subclass.")
if self.use_ray:
from vllm.executor import ray_utils from vllm.executor import ray_utils
ray_utils.assert_ray_available() ray_utils.assert_ray_available()
if is_hip(): if is_hip():
...@@ -715,8 +744,7 @@ class ParallelConfig: ...@@ -715,8 +744,7 @@ class ParallelConfig:
logger.info( logger.info(
"Disabled the custom all-reduce kernel because it is not " "Disabled the custom all-reduce kernel because it is not "
"supported on AMD GPUs.") "supported on AMD GPUs.")
if self.ray_workers_use_nsight and ( if self.ray_workers_use_nsight and not self.use_ray:
not self.distributed_executor_backend == "ray"):
raise ValueError("Unable to use nsight profiling unless workers " raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.") "run with Ray.")
...@@ -775,7 +803,9 @@ class SchedulerConfig: ...@@ -775,7 +803,9 @@ class SchedulerConfig:
# for higher throughput. # for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill: if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).") logger.info(
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens)
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -868,6 +898,7 @@ class SpeculativeConfig: ...@@ -868,6 +898,7 @@ class SpeculativeConfig:
draft_token_acceptance_method: str, draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float], typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float], typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
) -> Optional["SpeculativeConfig"]: ) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None. """Create a SpeculativeConfig if possible, else return None.
...@@ -917,6 +948,11 @@ class SpeculativeConfig: ...@@ -917,6 +948,11 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha (Optional[float]): typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler. TypicalAcceptanceSampler.
disable_logprobs (Optional[bool]): If set to True, token log
probabilities are not returned during speculative decoding.
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.
Returns: Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
...@@ -1029,6 +1065,8 @@ class SpeculativeConfig: ...@@ -1029,6 +1065,8 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold = 0.09 typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None: if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3 typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True
return SpeculativeConfig( return SpeculativeConfig(
draft_model_config, draft_model_config,
...@@ -1042,6 +1080,7 @@ class SpeculativeConfig: ...@@ -1042,6 +1080,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\ typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs
) )
@staticmethod @staticmethod
...@@ -1126,6 +1165,7 @@ class SpeculativeConfig: ...@@ -1126,6 +1165,7 @@ class SpeculativeConfig:
draft_token_acceptance_method: str, draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float, typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float, typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
): ):
"""Create a SpeculativeConfig object. """Create a SpeculativeConfig object.
...@@ -1152,6 +1192,12 @@ class SpeculativeConfig: ...@@ -1152,6 +1192,12 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_alpha (Optional[float]): typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler. TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
""" """
self.draft_model_config = draft_model_config self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config self.draft_parallel_config = draft_parallel_config
...@@ -1165,6 +1211,7 @@ class SpeculativeConfig: ...@@ -1165,6 +1211,7 @@ class SpeculativeConfig:
typical_acceptance_sampler_posterior_threshold typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \ self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self._verify_args() self._verify_args()
...@@ -1436,23 +1483,32 @@ def _get_and_verify_max_len( ...@@ -1436,23 +1483,32 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None) rope_scaling = getattr(hf_config, "rope_scaling", None)
# The correct one should be "longrope", kept "su" here if rope_scaling is not None:
# to be backward compatible if "type" in rope_scaling:
if rope_scaling is not None and rope_scaling["type"] != "su" \ rope_type = rope_scaling["type"]
and rope_scaling["type"] != "longrope": elif "rope_type" in rope_scaling:
if disable_sliding_window: rope_type = rope_scaling["rope_type"]
# TODO(robertgshaw): Find a model that supports rope_scaling else:
# with sliding window to see if this case should be allowed. raise ValueError(
raise NotImplementedError( "rope_scaling must have a 'type' or 'rope_type' key.")
"Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can " # The correct one should be "longrope", kept "su" here
"investigate.") # to be backward compatible
assert "factor" in rope_scaling if rope_type not in ("su", "longrope", "llama3"):
scaling_factor = rope_scaling["factor"] if disable_sliding_window:
if rope_scaling["type"] == "yarn": # TODO(robertgshaw): Find a model that supports rope_scaling
derived_max_model_len = rope_scaling[ # with sliding window to see if this case should be allowed.
"original_max_position_embeddings"] raise NotImplementedError(
derived_max_model_len *= scaling_factor "Disabling sliding window is not supported for models "
"with rope_scaling. Please raise an issue so we can "
"investigate.")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor
# If the user specified a max length, make sure it is smaller than the # If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config. # derived length from the HF model config.
......
from pathlib import Path
from typing import Mapping, Optional
from urllib.parse import urlparse
import aiohttp
import requests
from vllm.version import __version__ as VLLM_VERSION
class HTTPConnection:
"""Helper class to send HTTP requests."""
def __init__(self, *, reuse_client: bool = True) -> None:
super().__init__()
self.reuse_client = reuse_client
self._sync_client: Optional[requests.Session] = None
self._async_client: Optional[aiohttp.ClientSession] = None
def get_sync_client(self) -> requests.Session:
if self._sync_client is None or not self.reuse_client:
self._sync_client = requests.Session()
return self._sync_client
# NOTE: We intentionally use an async function even though it is not
# required, so that the client is only accessible inside async event loop
async def get_async_client(self) -> aiohttp.ClientSession:
if self._async_client is None or not self.reuse_client:
self._async_client = aiohttp.ClientSession()
return self._async_client
def _validate_http_url(self, url: str):
parsed_url = urlparse(url)
if parsed_url.scheme not in ("http", "https"):
raise ValueError("Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'.")
def _headers(self, **extras: str) -> Mapping[str, str]:
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
def get_response(
self,
url: str,
*,
stream: bool = False,
timeout: Optional[float] = None,
extra_headers: Optional[Mapping[str, str]] = None,
):
self._validate_http_url(url)
client = self.get_sync_client()
extra_headers = extra_headers or {}
return client.get(url,
headers=self._headers(**extra_headers),
stream=stream,
timeout=timeout)
async def get_async_response(
self,
url: str,
*,
timeout: Optional[float] = None,
extra_headers: Optional[Mapping[str, str]] = None,
):
self._validate_http_url(url)
client = await self.get_async_client()
extra_headers = extra_headers or {}
return client.get(url,
headers=self._headers(**extra_headers),
timeout=timeout)
def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
with self.get_response(url, timeout=timeout) as r:
r.raise_for_status()
return r.content
async def async_get_bytes(
self,
url: str,
*,
timeout: Optional[float] = None,
) -> bytes:
async with await self.get_async_response(url, timeout=timeout) as r:
r.raise_for_status()
return await r.read()
def get_text(self, url: str, *, timeout: Optional[float] = None) -> str:
with self.get_response(url, timeout=timeout) as r:
r.raise_for_status()
return r.text
async def async_get_text(
self,
url: str,
*,
timeout: Optional[float] = None,
) -> str:
async with await self.get_async_response(url, timeout=timeout) as r:
r.raise_for_status()
return await r.text()
def get_json(self, url: str, *, timeout: Optional[float] = None) -> str:
with self.get_response(url, timeout=timeout) as r:
r.raise_for_status()
return r.json()
async def async_get_json(
self,
url: str,
*,
timeout: Optional[float] = None,
) -> str:
async with await self.get_async_response(url, timeout=timeout) as r:
r.raise_for_status()
return await r.json()
def download_file(
self,
url: str,
save_path: Path,
*,
timeout: Optional[float] = None,
chunk_size: int = 128,
) -> Path:
with self.get_response(url, timeout=timeout) as r:
r.raise_for_status()
with save_path.open("wb") as f:
for chunk in r.iter_content(chunk_size):
f.write(chunk)
return save_path
async def async_download_file(
self,
url: str,
save_path: Path,
*,
timeout: Optional[float] = None,
chunk_size: int = 128,
) -> Path:
async with await self.get_async_response(url, timeout=timeout) as r:
r.raise_for_status()
with save_path.open("wb") as f:
async for chunk in r.content.iter_chunked(chunk_size):
f.write(chunk)
return save_path
global_http_connection = HTTPConnection()
"""The global :class:`HTTPConnection` instance used by vLLM."""
import math
from typing import List, Optional from typing import List, Optional
from vllm.core.block.common import BlockList from vllm.core.block.common import BlockList
...@@ -337,10 +338,17 @@ class BlockTable: ...@@ -337,10 +338,17 @@ class BlockTable:
This is required for the scheduler to determine whether a sequence can This is required for the scheduler to determine whether a sequence can
continue generation, or if it must be preempted. continue generation, or if it must be preempted.
""" """
# Math below is equivalent to:
# all_token_ids = token_ids + [-1] * num_lookahead_slots
# token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
# return len(token_blocks)
all_token_ids = token_ids + [-1] * num_lookahead_slots num_token_ids = len(token_ids) + num_lookahead_slots
token_blocks = self._chunk_token_blocks_for_append(all_token_ids) first_chunk_size = self._block_size - (self._num_full_slots %
return len(token_blocks) self._block_size)
num_token_blocks = (1 + math.ceil(
(num_token_ids - first_chunk_size) / self._block_size))
return num_token_blocks
def _chunk_token_blocks_for_append( def _chunk_token_blocks_for_append(
self, token_ids: List[int]) -> List[List[int]]: self, token_ids: List[int]) -> List[List[int]]:
...@@ -351,6 +359,7 @@ class BlockTable: ...@@ -351,6 +359,7 @@ class BlockTable:
""" """
first_chunk_size = self._block_size - (self._num_full_slots % first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size) self._block_size)
token_blocks = [token_ids[:first_chunk_size]] + chunk_list( token_blocks = [token_ids[:first_chunk_size]]
token_ids[first_chunk_size:], self._block_size) token_blocks.extend(
chunk_list(token_ids[first_chunk_size:], self._block_size))
return token_blocks return token_blocks
...@@ -552,9 +552,12 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -552,9 +552,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# runner. # runner.
# It returns a list of int although type annotation says list of string. # It returns a list of int although type annotation says list of string.
if len(computed_seq_block_ids) == 1:
return computed_seq_block_ids[0]
return commonprefix([ return commonprefix([
ids for ids in computed_seq_block_ids # type: ignore ids for ids in computed_seq_block_ids # type: ignore
if ids != [] if ids
]) ])
def get_num_blocks_touched(self, def get_num_blocks_touched(self,
......
...@@ -374,6 +374,7 @@ class Scheduler: ...@@ -374,6 +374,7 @@ class Scheduler:
for aborted_group in aborted_groups: for aborted_group in aborted_groups:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(aborted_group) state_queue.remove(aborted_group)
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs(): for seq in aborted_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
continue continue
......
...@@ -189,10 +189,10 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: ...@@ -189,10 +189,10 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None: if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
path = os.path.expanduser( path = os.path.join(
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" envs.VLLM_CACHE_ROOT,
) f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0) if ((not is_distributed or get_world_group().local_rank == 0)
......
...@@ -108,8 +108,15 @@ class ShmRingBuffer: ...@@ -108,8 +108,15 @@ class ShmRingBuffer:
# created by the process. The following patch is a workaround. # created by the process. The following patch is a workaround.
with patch("multiprocessing.resource_tracker.register", with patch("multiprocessing.resource_tracker.register",
lambda *args, **kwargs: None): lambda *args, **kwargs: None):
self.shared_memory = shared_memory.SharedMemory(name=name) try:
assert self.shared_memory.size == self.total_bytes_of_buffer self.shared_memory = shared_memory.SharedMemory(name=name)
assert (
self.shared_memory.size == self.total_bytes_of_buffer)
except FileNotFoundError:
# we might deserialize the object in a different node
# in this case, this object is not used,
# and we should suppress the error
pass
def __reduce__(self): def __reduce__(self):
return ( return (
...@@ -119,9 +126,10 @@ class ShmRingBuffer: ...@@ -119,9 +126,10 @@ class ShmRingBuffer:
) )
def __del__(self): def __del__(self):
self.shared_memory.close() if hasattr(self, "shared_memory"):
if self.is_creator: self.shared_memory.close()
self.shared_memory.unlink() if self.is_creator:
self.shared_memory.unlink()
@contextmanager @contextmanager
def get_data(self, current_idx: int): def get_data(self, current_idx: int):
...@@ -170,7 +178,7 @@ class MessageQueue: ...@@ -170,7 +178,7 @@ class MessageQueue:
self.n_remote_reader = n_remote_reader self.n_remote_reader = n_remote_reader
if connect_ip is None: if connect_ip is None:
connect_ip = get_ip() connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
context = Context() context = Context()
...@@ -230,6 +238,8 @@ class MessageQueue: ...@@ -230,6 +238,8 @@ class MessageQueue:
remote_sync_port=remote_sync_port, remote_sync_port=remote_sync_port,
) )
logger.info("vLLM message queue communication handle: %s", self.handle)
def export_handle(self) -> Handle: def export_handle(self) -> Handle:
return self.handle return self.handle
...@@ -335,8 +345,8 @@ class MessageQueue: ...@@ -335,8 +345,8 @@ class MessageQueue:
time.sleep(RINGBUFFER_SLEEP_INTERVAL) time.sleep(RINGBUFFER_SLEEP_INTERVAL)
# if we wait for a long time, we should warn the user # if we wait for a long time, we should warn the user
if time.monotonic( if (time.monotonic() - start_time >
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning( logger.warning(
"No available block found in %s second. ", "No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL) VLLM_RINGBUFFER_WARNING_INTERVAL)
...@@ -389,8 +399,8 @@ class MessageQueue: ...@@ -389,8 +399,8 @@ class MessageQueue:
time.sleep(RINGBUFFER_SLEEP_INTERVAL) time.sleep(RINGBUFFER_SLEEP_INTERVAL)
# if we wait for a long time, we should warn the user # if we wait for a long time, we should warn the user
if time.monotonic( if (time.monotonic() - start_time >
) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.warning( logger.warning(
"No available block found in %s second. ", "No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL) VLLM_RINGBUFFER_WARNING_INTERVAL)
...@@ -426,7 +436,6 @@ class MessageQueue: ...@@ -426,7 +436,6 @@ class MessageQueue:
def dequeue(self): def dequeue(self):
if self._is_local_reader: if self._is_local_reader:
overflow = False
with self.acquire_read() as buf: with self.acquire_read() as buf:
overflow = buf[0] == 1 overflow = buf[0] == 1
if not overflow: if not overflow:
......
...@@ -2,16 +2,24 @@ import argparse ...@@ -2,16 +2,24 @@ import argparse
import dataclasses import dataclasses
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig) SpeculativeConfig, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
logger = init_logger(__name__)
def nullable_str(val: str): def nullable_str(val: str):
if not val or val == "None": if not val or val == "None":
...@@ -36,7 +44,11 @@ class EngineArgs: ...@@ -36,7 +44,11 @@ class EngineArgs:
seed: int = 0 seed: int = 0
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
worker_use_ray: bool = False worker_use_ray: bool = False
distributed_executor_backend: Optional[str] = None # Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[str,
Type[ExecutorBase]]] = None
pipeline_parallel_size: int = 1 pipeline_parallel_size: int = 1
tensor_parallel_size: int = 1 tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None max_parallel_loading_workers: Optional[int] = None
...@@ -45,6 +57,7 @@ class EngineArgs: ...@@ -45,6 +57,7 @@ class EngineArgs:
disable_sliding_window: bool = False disable_sliding_window: bool = False
use_v2_block_manager: bool = False use_v2_block_manager: bool = False
swap_space: int = 4 # GiB swap_space: int = 4 # GiB
cpu_offload_gb: int = 0 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256 max_num_seqs: int = 256
...@@ -61,7 +74,10 @@ class EngineArgs: ...@@ -61,7 +74,10 @@ class EngineArgs:
max_seq_len_to_capture: int = 8192 max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0 tokenizer_pool_size: int = 0
tokenizer_pool_type: str = "ray" # Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[dict] = None tokenizer_pool_extra_config: Optional[dict] = None
enable_lora: bool = False enable_lora: bool = False
max_loras: int = 1 max_loras: int = 1
...@@ -79,10 +95,11 @@ class EngineArgs: ...@@ -79,10 +95,11 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None model_loader_extra_config: Optional[dict] = None
ignore_patterns: Optional[Union[str, List[str]]] = None
preemption_mode: Optional[str] = None preemption_mode: Optional[str] = None
scheduler_delay_factor: float = 0.0 scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False enable_chunked_prefill: Optional[bool] = None
guided_decoding_backend: str = 'outlines' guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration. # Speculative decoding configuration.
...@@ -97,6 +114,7 @@ class EngineArgs: ...@@ -97,6 +114,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold: Optional[float] = None typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
...@@ -303,6 +321,20 @@ class EngineArgs: ...@@ -303,6 +321,20 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.swap_space, default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.') help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--cpu-offload-gb',
type=float,
default=0,
help='The space in GiB to offload to CPU, per GPU. '
'Default is 0, which means no offloading. Intuitively, '
'this argument can be seen as a virtual way to increase '
'the GPU memory size. For example, if you have one 24 GB '
'GPU and set this to 10, virtually you can think of it as '
'a 34 GB GPU. Then you can load a 13B model with BF16 weight,'
'which requires at least 26GB GPU memory. Note that this '
'requires fast CPU-GPU interconnect, as part of the model is'
'loaded from CPU memory to GPU memory on the fly in each '
'model forward pass.')
parser.add_argument( parser.add_argument(
'--gpu-memory-utilization', '--gpu-memory-utilization',
type=float, type=float,
...@@ -480,7 +512,10 @@ class EngineArgs: ...@@ -480,7 +512,10 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.') 'prompt latency) before scheduling next prompt.')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
action='store_true', action=StoreBoolean,
default=EngineArgs.enable_chunked_prefill,
nargs="?",
const="True",
help='If set, the prefill requests can be chunked based on the ' help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.') 'max_num_batched_tokens.')
...@@ -565,6 +600,18 @@ class EngineArgs: ...@@ -565,6 +600,18 @@ class EngineArgs:
'to sqrt of --typical-acceptance-sampler-posterior-threshold ' 'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3') 'i.e. 0.3')
parser.add_argument(
'--disable-logprobs-during-spec-decoding',
type=bool,
default=EngineArgs.disable_logprobs_during_spec_decoding,
help='If set to True, token log probabilities are not returned '
'during speculative decoding. If set to False, log probabilities '
'are returned according to the settings in SamplingParams. If '
'not specified, it defaults to True. Disabling log probabilities '
'during speculative decoding reduces latency by skipping logprob '
'calculation in proposal sampling, target sampling, and after '
'accepted tokens are determined.')
parser.add_argument('--model-loader-extra-config', parser.add_argument('--model-loader-extra-config',
type=nullable_str, type=nullable_str,
default=EngineArgs.model_loader_extra_config, default=EngineArgs.model_loader_extra_config,
...@@ -573,6 +620,14 @@ class EngineArgs: ...@@ -573,6 +620,14 @@ class EngineArgs:
'corresponding to the chosen load_format. ' 'corresponding to the chosen load_format. '
'This should be a JSON string that will be ' 'This should be a JSON string that will be '
'parsed into a dictionary.') 'parsed into a dictionary.')
parser.add_argument(
'--ignore-patterns',
action="append",
type=str,
default=[],
help="The pattern(s) to ignore when loading the model."
"Default to 'original/**/*' to avoid repeated loading of llama's "
"checkpoints.")
parser.add_argument( parser.add_argument(
'--preemption-mode', '--preemption-mode',
type=str, type=str,
...@@ -633,6 +688,11 @@ class EngineArgs: ...@@ -633,6 +688,11 @@ class EngineArgs:
raise ValueError( raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support " "BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}") f"'bitsandbytes' quantization, but got {self.quantization}")
assert self.cpu_offload_gb >= 0, (
"CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
device_config = DeviceConfig(device=self.device) device_config = DeviceConfig(device=self.device)
...@@ -666,7 +726,9 @@ class EngineArgs: ...@@ -666,7 +726,9 @@ class EngineArgs:
cache_dtype=self.kv_cache_dtype, cache_dtype=self.kv_cache_dtype,
num_gpu_blocks_override=self.num_gpu_blocks_override, num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(), sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching) enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
)
parallel_config = ParallelConfig( parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size, pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size, tensor_parallel_size=self.tensor_parallel_size,
...@@ -681,6 +743,38 @@ class EngineArgs: ...@@ -681,6 +743,38 @@ class EngineArgs:
ray_workers_use_nsight=self.ray_workers_use_nsight, ray_workers_use_nsight=self.ray_workers_use_nsight,
distributed_executor_backend=self.distributed_executor_backend) distributed_executor_backend=self.distributed_executor_backend)
max_model_len = model_config.max_model_len
use_long_context = max_model_len > 32768
if self.enable_chunked_prefill is None:
# If not explicitly set, enable chunked prefill by default for
# long context (> 32K) models. This is to avoid OOM errors in the
# initial memory profiling phase.
if use_long_context:
is_gpu = device_config.device_type == "cuda"
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora
and not self.enable_prompt_adapter
and not self.enable_prefix_caching):
self.enable_chunked_prefill = True
logger.warning(
"Chunked prefill is enabled by default for models with "
"max_model_len > 32K. Currently, chunked prefill might "
"not work with some features or models. If you "
"encounter any issues, please disable chunked prefill "
"by setting --enable-chunked-prefill=False.")
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = False
if not self.enable_chunked_prefill and use_long_context:
logger.warning(
"The model has a long context length (%s). This may cause OOM "
"errors during the initial memory profiling phase, or result "
"in low performance due to small KV cache space. Consider "
"setting --max-model-len to a smaller value.", max_model_len)
speculative_config = SpeculativeConfig.maybe_create_spec_config( speculative_config = SpeculativeConfig.maybe_create_spec_config(
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
...@@ -702,6 +796,7 @@ class EngineArgs: ...@@ -702,6 +796,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold, typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self. typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha, typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
) )
scheduler_config = SchedulerConfig( scheduler_config = SchedulerConfig(
...@@ -738,6 +833,7 @@ class EngineArgs: ...@@ -738,6 +833,7 @@ class EngineArgs:
load_format=self.load_format, load_format=self.load_format,
download_dir=self.download_dir, download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config, model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
) )
prompt_adapter_config = PromptAdapterConfig( prompt_adapter_config = PromptAdapterConfig(
...@@ -779,7 +875,6 @@ class AsyncEngineArgs(EngineArgs): ...@@ -779,7 +875,6 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine.""" """Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False engine_use_ray: bool = False
disable_log_requests: bool = False disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod @staticmethod
def add_cli_args(parser: FlexibleArgumentParser, def add_cli_args(parser: FlexibleArgumentParser,
...@@ -793,15 +888,21 @@ class AsyncEngineArgs(EngineArgs): ...@@ -793,15 +888,21 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests', parser.add_argument('--disable-log-requests',
action='store_true', action='store_true',
help='Disable logging requests.') help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser return parser
class StoreBoolean(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
setattr(namespace, self.dest, True)
elif values.lower() == "false":
setattr(namespace, self.dest, False)
else:
raise ValueError(f"Invalid boolean value: {values}. "
"Expected 'true' or 'false'.")
# These functions are used by sphinx to build the documentation # These functions are used by sphinx to build the documentation
def _engine_args_parser(): def _engine_args_parser():
return EngineArgs.add_cli_args(FlexibleArgumentParser()) return EngineArgs.add_cli_args(FlexibleArgumentParser())
......
import asyncio import asyncio
import time import time
from functools import partial from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Set, Tuple, Type, Union) Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
import vllm.envs as envs import vllm.envs as envs
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -129,7 +131,10 @@ class RequestTracker: ...@@ -129,7 +131,10 @@ class RequestTracker:
"""Process a request output from the engine.""" """Process a request output from the engine."""
request_id = request_output.request_id request_id = request_output.request_id
self._request_streams[request_id].put(request_output) # Guard against a KeyError which can occur if the request was aborted
# while the output was generated
if (stream := self._request_streams.get(request_id)) is not None:
stream.put(request_output)
if request_output.finished: if request_output.finished:
if verbose: if verbose:
logger.info("Finished request %s.", request_id) logger.info("Finished request %s.", request_id)
...@@ -146,7 +151,10 @@ class RequestTracker: ...@@ -146,7 +151,10 @@ class RequestTracker:
logger.info("Finished request %s.", request_id) logger.info("Finished request %s.", request_id)
self.abort_request(request_id) self.abort_request(request_id)
def add_request(self, request_id: str, def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream: **engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background """Add a request to be sent to the engine on the next background
loop iteration.""" loop iteration."""
...@@ -161,6 +169,9 @@ class RequestTracker: ...@@ -161,6 +169,9 @@ class RequestTracker:
self.new_requests_event.set() self.new_requests_event.set()
if verbose:
logger.info("Added request %s.", request_id)
return stream return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None: def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
...@@ -294,14 +305,14 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -294,14 +305,14 @@ class _AsyncLLMEngine(LLMEngine):
return self.input_processor(llm_inputs) return self.input_processor(llm_inputs)
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
...@@ -348,8 +359,6 @@ class AsyncLLMEngine: ...@@ -348,8 +359,6 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the async frontend will be executed in a separate process as the
model workers. model workers.
log_requests: Whether to log the requests. log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
being printed in log.
start_engine_loop: If True, the background task to run the engine start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call. will be automatically started in the generate call.
*args: Arguments for :class:`LLMEngine`. *args: Arguments for :class:`LLMEngine`.
...@@ -363,13 +372,11 @@ class AsyncLLMEngine: ...@@ -363,13 +372,11 @@ class AsyncLLMEngine:
engine_use_ray: bool, engine_use_ray: bool,
*args, *args,
log_requests: bool = True, log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True, start_engine_loop: bool = True,
**kwargs) -> None: **kwargs) -> None:
self.worker_use_ray = worker_use_ray self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray self.engine_use_ray = engine_use_ray
self.log_requests = log_requests self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs) self.engine = self._init_engine(*args, **kwargs)
self.background_loop: Optional[asyncio.Future] = None self.background_loop: Optional[asyncio.Future] = None
...@@ -384,24 +391,19 @@ class AsyncLLMEngine: ...@@ -384,24 +391,19 @@ class AsyncLLMEngine:
self._request_tracker: RequestTracker self._request_tracker: RequestTracker
@classmethod @classmethod
def from_engine_args( def _get_executor_cls(
cls, cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):
if engine_config.device_config.device_type == "neuron": if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu": elif engine_config.device_config.device_type == "tpu":
...@@ -440,17 +442,37 @@ class AsyncLLMEngine: ...@@ -440,17 +442,37 @@ class AsyncLLMEngine:
else: else:
from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync executor_class = GPUExecutorAsync
return executor_class
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_args.engine_use_ray:
from vllm.executor import ray_utils
ray_utils.assert_ray_available()
executor_class = cls._get_executor_cls(engine_config)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls( engine = cls(
distributed_executor_backend == "ray", executor_class.uses_ray,
engine_args.engine_use_ray, engine_args.engine_use_ray,
**engine_config.to_dict(), **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop, start_engine_loop=start_engine_loop,
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers,
) )
return engine return engine
...@@ -477,11 +499,16 @@ class AsyncLLMEngine: ...@@ -477,11 +499,16 @@ class AsyncLLMEngine:
self.set_errored(exc) self.set_errored(exc)
self._request_tracker.propagate_exception(exc) self._request_tracker.propagate_exception(exc)
async def get_tokenizer(self) -> "PreTrainedTokenizer": async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
if self.engine_use_ray: if self.engine_use_ray:
return await self.engine.get_tokenizer.remote() # type: ignore return await self.engine.get_tokenizer.remote( # type: ignore
else: lora_request)
return self.engine.get_tokenizer()
return await (self.engine.get_tokenizer_group().
get_lora_tokenizer_async(lora_request))
def start_background_loop(self) -> None: def start_background_loop(self) -> None:
"""Start the background loop.""" """Start the background loop."""
...@@ -641,30 +668,9 @@ class AsyncLLMEngine: ...@@ -641,30 +668,9 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream: ) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
self.start_background_loop() self.start_background_loop()
...@@ -680,6 +686,7 @@ class AsyncLLMEngine: ...@@ -680,6 +686,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
verbose=self.log_requests,
inputs=inputs, inputs=inputs,
params=params, params=params,
arrival_time=arrival_time, arrival_time=arrival_time,
...@@ -695,7 +702,7 @@ class AsyncLLMEngine: ...@@ -695,7 +702,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]: ) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request. """Generate outputs for a request.
...@@ -778,7 +785,7 @@ class AsyncLLMEngine: ...@@ -778,7 +785,7 @@ class AsyncLLMEngine:
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]: ) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
...@@ -856,7 +863,7 @@ class AsyncLLMEngine: ...@@ -856,7 +863,7 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
*, *,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or """Common logic to process requests with SamplingParams or
...@@ -957,3 +964,19 @@ class AsyncLLMEngine: ...@@ -957,3 +964,19 @@ class AsyncLLMEngine:
) )
else: else:
return self.engine.is_tracing_enabled() return self.engine.is_tracing_enabled()
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if self.engine_use_ray:
ray.get(
self.engine.add_logger.remote( # type: ignore
logger_name=logger_name, logger=logger))
else:
self.engine.add_logger(logger_name=logger_name, logger=logger)
def remove_logger(self, logger_name: str) -> None:
if self.engine_use_ray:
ray.get(
self.engine.remove_logger.remote( # type: ignore
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, import vllm.envs as envs
LoRAConfig, ModelConfig, MultiModalConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
ObservabilityConfig, ParallelConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig) SpeculativeConfig)
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
...@@ -375,18 +377,20 @@ class LLMEngine: ...@@ -375,18 +377,20 @@ class LLMEngine:
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
@classmethod @classmethod
def from_engine_args( def _get_executor_cls(cls,
cls, engine_config: EngineConfig) -> Type[ExecutorBase]:
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = ( distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend) engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class. # Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron": if isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu": elif engine_config.device_config.device_type == "tpu":
...@@ -413,17 +417,35 @@ class LLMEngine: ...@@ -413,17 +417,35 @@ class LLMEngine:
elif distributed_executor_backend == "mp": elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import ( from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutor) MultiprocessingGPUExecutor)
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
"multiprocessing distributed executor backend does not "
"support VLLM_USE_RAY_SPMD_WORKER=1")
executor_class = MultiprocessingGPUExecutor executor_class = MultiprocessingGPUExecutor
else: else:
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor executor_class = GPUExecutor
return executor_class
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls( engine = cls(
**engine_config.to_dict(), **engine_config.to_dict(),
executor_class=executor_class, executor_class=executor_class,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,
usage_context=usage_context, usage_context=usage_context,
stat_loggers=stat_loggers,
) )
return engine return engine
def __reduce__(self): def __reduce__(self):
...@@ -448,8 +470,11 @@ class LLMEngine: ...@@ -448,8 +470,11 @@ class LLMEngine:
return self.tokenizer return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(
return self.get_tokenizer_group().get_lora_tokenizer(None) self,
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def get_tokenizer_for_seq(self, def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer": sequence: Sequence) -> "PreTrainedTokenizer":
...@@ -498,7 +523,7 @@ class LLMEngine: ...@@ -498,7 +523,7 @@ class LLMEngine:
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> None: ) -> None:
# Create the sequences. # Create the sequences.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
...@@ -579,7 +604,7 @@ class LLMEngine: ...@@ -579,7 +604,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
...@@ -653,7 +678,7 @@ class LLMEngine: ...@@ -653,7 +678,7 @@ class LLMEngine:
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: float, arrival_time: float,
lora_request: Optional[LoRARequest], lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment