Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
......@@ -12,9 +12,10 @@ import torch
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.utils import FlexibleArgumentParser
else:
ModelConfig = None
VllmConfig = None
FlexibleArgumentParser = None
......@@ -376,6 +377,20 @@ class Platform:
or parallel_config.distributed_executor_backend
== "external_launcher")
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
"""Returns whether the current platform can support v1 for the supplied
model configuration.
"""
return False
@classmethod
def use_custom_allreduce(cls) -> bool:
"""
Returns if custom allreduce is supported on the current platform
"""
return False
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
# SPDX-License-Identifier: Apache-2.0
import os
from functools import lru_cache, wraps
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional
import torch
......@@ -12,15 +12,17 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
try:
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
amdsmi_init, amdsmi_shut_down, AmdSmiException, amdsmi_topo_get_link_type)
from amdsmi import (AmdSmiException, amdsmi_get_gpu_asic_info,
amdsmi_get_processor_handles, amdsmi_init,
amdsmi_shut_down, amdsmi_topo_get_link_type)
except ImportError as e:
logger.warning("Failed to import from amdsmi with %r", e)
......@@ -97,6 +99,26 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id
@cache
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int,
sliding_window: int) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_NAVI = "gfx1" in GPU_ARCH
ON_MI250_MI300 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])
# rocm custom page attention not support on navi (gfx1*)
return (ON_MI250_MI300 and not ON_NAVI
and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
......@@ -135,22 +157,15 @@ class RocmPlatform(Platform):
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
def get_device_capability(cls,
device_id: int = 0
) -> Optional[DeviceCapability]:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
@classmethod
@with_amdsmi_context
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return torch.cuda.get_device_name(device_id)
@staticmethod
def is_fully_connected_nvlink_or_xgmi(
physical_device_ids: List[int]) -> bool:
@with_amdsmi_context
def is_fully_connected(physical_device_ids: List[int]) -> bool:
"""
Query if the set of gpus are fully connected by xgmi (1 hop)
"""
......@@ -172,6 +187,15 @@ class RocmPlatform(Platform):
return False
return True
@classmethod
@with_amdsmi_context
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
handle = amdsmi_get_processor_handles()[physical_device_id]
# return amdsmi_get_gpu_asic_info(handle)["market_name"]
return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.cuda.get_device_properties(device_id)
......@@ -276,3 +300,15 @@ class RocmPlatform(Platform):
return torch.float8_e4m3fnuz
else:
return torch.float8_e4m3fn
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
# V1 support on AMD gpus is experimental
return True
@classmethod
def use_custom_allreduce(cls) -> bool:
# We only enable custom allreduce for MI300 series
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ['gfx94']
return any(gfx in gcn_arch for gfx in supported_archs)
......@@ -10,8 +10,9 @@ from vllm.logger import init_logger
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None
logger = init_logger(__name__)
......@@ -127,3 +128,8 @@ class TpuPlatform(Platform):
@classmethod
def use_all_gather(cls) -> bool:
return True
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
# V1 support on TPU is experimental
return True
......@@ -2,7 +2,11 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
__all__ = [
"ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
"ReasoningParser",
"ReasoningParserManager",
"DeepSeekR1ReasoningParser",
"GraniteReasoningParser",
]
......@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@abstractmethod
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
......@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!")
@abstractmethod
def extract_reasoning_content_streaming(
self,
previous_text: str,
......@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!")
class ReasoningParserManager:
......@@ -125,14 +115,16 @@ class ReasoningParserManager:
if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name]
raise KeyError(f"reasoning helper: '{name}' not found in "
"reasoning_parsers")
raise KeyError(
f"reasoning helper: '{name}' not found in reasoning_parsers")
@classmethod
def _register_module(cls,
def _register_module(
cls,
module: type,
module_name: Optional[Union[str, list[str]]] = None,
force: bool = True) -> None:
force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}")
......@@ -152,7 +144,8 @@ class ReasoningParserManager:
cls,
name: Optional[Union[str, list[str]]] = None,
force: bool = True,
module: Union[type, None] = None) -> Union[type, Callable]:
module: Union[type, None] = None,
) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
......
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Sequence
from typing import Optional, Union
......@@ -8,9 +7,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
......@@ -24,39 +22,38 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.think_end_token_id not in input_ids[:-1]:
if self.end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
......@@ -77,22 +74,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id
self.start_token_id, self.end_token_id
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
if self.start_token_id in previous_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.think_end_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
......@@ -100,17 +99,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
if self.think_end_token_id in delta_token_ids:
elif self.start_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
end_index = delta_text.find(self.think_end_token)
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index +
len(self.think_start_token
):end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
len(self.start_token):end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
else:
# <think> in delta, no </think> in delta,
# reasoning content continues
......@@ -119,15 +119,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.think_end_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
......@@ -137,25 +139,34 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
"""
Extract reasoning content from the model output.
For text <think>abc</think>xyz:
- 'abc' goes to reasoning_content
- 'xyz' goes to content
Returns:
tuple[Optional[str], Optional[str]]: reasoning content and content
"""
# Check if the start token is present in the model output, remove it
# if it is present.
model_output_parts = model_output.partition(self.start_token)
model_output = model_output_parts[2] if model_output_parts[
1] else model_output_parts[0]
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output:
if self.end_token not in model_output:
return model_output, None
else:
# Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}"
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
final_output = model_output[end_index:]
if len(final_output) == 0:
return reasoning_content, None
return reasoning_content, final_output
reasoning_content, _, content = model_output.partition(
self.end_token)
# If the end token is not found, return the model output as is.
# It should not happen since we already checked for the presence
# of the end token.
# If generation stops right after end-of-think, return null content
final_content = content or None
return reasoning_content, final_content
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Sequence
from typing import Optional, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@ReasoningParserManager.register_module("granite")
class GraniteReasoningParser(ReasoningParser):
"""
Reasoning parser for IBM Granite.
IBM granite models currently use "Here is my thought process:"
and "Here is my response:" to separate its thinking / response outputs.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# NOTE: There have been some observed occurrences of quantized
# instances of the current models using "Here's" instead of "Here is",
# so to be safe, we match on both.
self.think_start_expr = r"(?:Here's|Here is) my thought process:"
self.response_start_expr = r"(?:Here's|Here is) my response:"
self.reasoning_regex = re.compile(
rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
re.DOTALL)
self.valid_think_starts = [
"Here's my thought process:", "Here is my thought process:"
]
self.valid_response_starts = [
"Here's my response:", "Here is my response:"
]
# Substrings to match for sequence boundaries on raw text
self.seq_boundary_end = ":"
self.seq_boundary_start = "Here"
# The longest any thinking / start of response message can be
self.longest_think_start = max(
len(think_start) for think_start in self.valid_think_starts)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
"""Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
reasoning content and non-reasoning content.
"""
re_match = self.reasoning_regex.findall(model_output)
if not re_match:
return None, model_output
reasoning_content, response_content = re_match[0]
if not response_content:
return reasoning_content, None
return reasoning_content, response_content
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""Extract the reasoning content / content emitted by granite models;
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
NOTE: Granite models do not use a special token to start their reasoning
and response sections; instead they have token sequences, e.g.,
Here is my thought process: Foo Here is my response: Bar
This increases the complexity of correctly handling streams, since we
need to watch for specific sequences and correctly parse them without
dropping content that is potentially overlapping & spanning multiple
delta messages.
Args:
previous_text (str): Previous text outside of this delta message.
current_text (str): Previous text + delta text.
delta_text (str): Text to consider and parse content from.
previous_token_ids (Sequence[int]): Token IDs of previous_text.
current_token_ids (Sequence[int]): Token IDs of current_text.
delta_token_ids (Sequence[int]): Token IDs of delta_text.
Returns:
Union[DeltaMessage, None]
DeltaMessage with either reasoning content or content, or None.
"""
reasoning_content, resp_seq_len, content = self._get_content_sections(
current_text)
# Either we haven't finished the start of the reasoning sequence,
# or the model is generating something unexpected.
if not reasoning_content:
delta_message = self._get_delta_message_with_no_reasoning_bounds(
current_text, delta_text)
# We have a start of reasoning message, but have not yet finished
# the start of response sequence.
elif not content:
delta_message = self._get_delta_message_with_no_response_bounds(
current_text, reasoning_content, delta_text)
# We've finished both the start of reasoning and start of response seq.
else:
# This should never happen since we matched on the response
assert resp_seq_len is not None
delta_message = self._get_delta_message_with_both_bounds(
delta_text, reasoning_content, content, current_text,
resp_seq_len)
if not delta_message.content and not delta_message.reasoning_content:
return None
return delta_message
#### Implementation details of stream parsing for granite models
def _is_reasoning_start_substr(self, text: str) -> bool:
"""Check if a text matches one of the possible start reasoning seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible reasoning start seqs match.
"""
return any(
think_start.startswith(text)
for think_start in self.valid_think_starts)
def _is_response_start_substr(self, text: str) -> bool:
"""Check if a text matches one of the possible start response seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible response start seqs match.
"""
return any(
response_start.startswith(text)
for response_start in self.valid_response_starts)
def _get_delta_message_with_no_reasoning_bounds(
self,
current_text: str,
delta_text: str,
) -> DeltaMessage:
"""Parse the delta message when the current text has not yet completed
its start of reasoning sequence.
Args:
current_text (str): The full previous + delta text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
prev_longest_length = len(current_text) - len(delta_text)
is_substr = self._is_reasoning_start_substr(current_text)
was_substr = self._is_reasoning_start_substr(
current_text[:prev_longest_length])
# Check if we just generated something NOT in the special token seq;
# if so, add everything that we previously skipped with this delta
# message and append everything to content in the future.
if was_substr and not is_substr:
return DeltaMessage(
reasoning_content=None,
content=current_text,
)
if is_substr:
# Might still be in the special token sequence; return nothing
return DeltaMessage(reasoning_content=None, content=None)
# Otherwise the sequence has already been broken and we already
# corrected; just return the delta text as normal content.
return DeltaMessage(reasoning_content=None, content=delta_text)
def _get_delta_message_with_no_response_bounds(
self,
current_text: str,
reasoning_content: str,
delta_text: str,
) -> DeltaMessage:
"""Parse the delta message when the current text has both reasoning
content with no (response) content. NOTE that we may have overlapping
tokens with the start of reasoning / start of response sequences on
either side of the delta text.
Args:
current_text (str): The full previous + delta text.
reasoning_content (str): reasoning content from current_text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# If we have no reasoning content or explicitly end with the start of
# response sequence, we are in transition to the response; need to be
# careful here, since the final token (:) will match the reasoning
# content and fully parse it out; we should not pass the : back.
ends_with_start_response_seq = any(
current_text.endswith(response_start)
for response_start in self.valid_response_starts)
if reasoning_content is None or ends_with_start_response_seq:
return DeltaMessage(reasoning_content=None, content=None)
# Consider previous / current text only within context of the reasoning
previous_text = reasoning_content[:-len(delta_text)]
current_text = reasoning_content
# We need to be careful about adding unfinished response sequences;
# Find the place at which we MIGHT be starting a response sequence
prev_idx = previous_text.rfind(self.seq_boundary_start)
delta_idx = delta_text.rfind(self.seq_boundary_start)
# Check the state of potential start of response substring matches.
prev_was_substr = self._is_response_start_substr(
previous_text[prev_idx:]) if prev_idx >= 0 else False
delta_continues_substr = self._is_response_start_substr(
current_text[prev_idx:]) if prev_idx >= 0 else False
delta_new_substr = self._is_response_start_substr(
delta_text[delta_idx:]) if delta_idx >= 0 else False
# Delta only contains potential continued response sequence text.
if delta_continues_substr:
return DeltaMessage(reasoning_content=None, content=None)
if not prev_was_substr:
# Delta may be starting a new response seq but has other text too.
if delta_new_substr:
return DeltaMessage(reasoning_content=delta_text[:delta_idx],
content=None)
# Normal case for most reasoning text (no potential special seqs).
return DeltaMessage(reasoning_content=delta_text, content=None)
# The substring that previously seemed to be a potential response
# seq wasn't one; we need to add the content to the delta message,
# and also slice off the potential response sequence
elif delta_new_substr:
reasoning_content = previous_text[
prev_idx:] + delta_text[:delta_idx]
return DeltaMessage(reasoning_content=reasoning_content,
content=None)
# No new substring yet, and we broke our old one; take the whole delta
return DeltaMessage(
reasoning_content=previous_text[prev_idx:] + delta_text,
content=None,
)
def _get_delta_message_with_both_bounds(
self,
delta_text: str,
reasoning_content: str,
response_content: str,
current_text: str,
response_seq_len: int,
) -> DeltaMessage:
"""Parse the delta message when the current text has both reasoning
content and normal (response) content.
Args:
delta_text (str): Text to consider and parse content from.
reasoning_content (str): reasoning content from current_text.
response_content (str): response content from current_text.
current_text (str): The full previous + delta text.
response_seq_len(str): Len of the complete response sequence used.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# Always have content; take length to the end
delta_content = delta_text[-len(response_content):]
reasoning_end_idx = len(delta_text) - (len(response_content) +
response_seq_len)
if reasoning_end_idx < 0:
delta_reasoning_content = None
else:
# Get the starting offset
start_reasoning_content_idx = len(
reasoning_content) + response_seq_len + len(
response_content) - 1
delta_offset = len(current_text) - len(delta_text)
start_offset = start_reasoning_content_idx - delta_offset
if start_offset < 0:
start_offset = 0
delta_reasoning_content = delta_text[
start_offset:reasoning_end_idx]
return DeltaMessage(
reasoning_content=delta_reasoning_content,
content=delta_content,
)
def _get_content_sections(
self, current_text: str
) -> tuple[Optional[str], Optional[int], Optional[str]]:
"""Parse the text to extract the reasoning content / content
if we have them.
Args:
current_text (str): The full previous + delta text.
Returns:
tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
containing the reasoning content, the length of the response seq
(if there is one) and the non-reasoning content.
"""
current_chunk_start = 0
start_reasoning_content = None
parsed_content = False
delimiter_idxs = [
idx for idx, char in enumerate(current_text)
if char == self.seq_boundary_end
]
for current_chunk_end in delimiter_idxs:
current_chunk = current_text[current_chunk_start:current_chunk_end]
# Check to see if the start of reasoning seq if complete
if start_reasoning_content is None:
for think_start in self.valid_think_starts:
if current_chunk == think_start[:-1]:
start_reasoning_content = current_chunk_end + 1
current_chunk_start = current_chunk_end + 1
break
# Check to see if the start of response seq if complete
elif not parsed_content:
for response_start in self.valid_response_starts:
if current_chunk[-len(response_start) +
1:] == response_start[:-1]:
# Mark end of reasoning and start response content
# after the start of response sequence.
end_reasoning_content = current_chunk_end - len(
response_start)
reasoning_content = current_text[
start_reasoning_content:end_reasoning_content]
response_content = current_text[current_chunk_end + 1:]
return reasoning_content, len(
response_start), response_content
if start_reasoning_content and not parsed_content:
return current_text[start_reasoning_content:], None, None
return None, None, None
......@@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Olmo2Config, RWConfig,
SolarConfig, Telechat2Config,
UltravoxConfig)
SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
......@@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"NVLM_D": NVLM_D_Config,
"olmo2": Olmo2Config,
"solar": SolarConfig,
"skywork_chat": SkyworkR1VChatConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
......@@ -261,6 +262,11 @@ def get_config(
MISTRAL_CONFIG_NAME,
revision=revision):
config_format = ConfigFormat.MISTRAL
else:
raise ValueError(
"Could not detect config format for no config file found. "
"Ensure your model has either config.json (HF format) "
"or params.json (Mistral format).")
except Exception as e:
error_message = (
......@@ -323,7 +329,14 @@ def get_config(
elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, token=HF_TOKEN, **kwargs)
else:
raise ValueError(f"Unsupported config format: {config_format}")
supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
]
raise ValueError(
f"Unsupported config format: {config_format}. "
f"Supported formats are: {', '.join(supported_formats)}. "
f"Ensure your model uses one of these configuration formats "
f"or specify the correct format explicitly.")
# Special architecture mapping check for GGUF models
if is_gguf:
......
......@@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
......@@ -42,6 +43,7 @@ __all__ = [
"NemotronConfig",
"NVLM_D_Config",
"Olmo2Config",
"SkyworkR1VChatConfig",
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",
......
......@@ -5,6 +5,8 @@ from typing import Optional, Union
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
class EAGLEConfig(PretrainedConfig):
model_type = "eagle"
......@@ -14,8 +16,17 @@ class EAGLEConfig(PretrainedConfig):
truncated_vocab_size: Optional[int] = None,
**kwargs):
model_config = None if model is None else (AutoConfig.for_model(
**model) if isinstance(model, dict) else model)
model_config: Union[PretrainedConfig, DeepseekV2Config, None]
if isinstance(model, dict):
archs = model.get("architectures", [])
target_archs = ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]
if any(target_arch in archs for target_arch in target_archs):
# AutoConfig does not support DeepSeek MoE models yet
model_config = DeepseekV2Config(**model)
else:
model_config = AutoConfig.for_model(**model)
else:
model_config = model
for k, v in kwargs.items():
if k != "architectures" and k != "model_type" and hasattr(
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig
class SkyworkR1VChatConfig(PretrainedConfig):
model_type = 'internvl_chat'
is_composition = True
def __init__(self,
vision_config=None,
llm_config=None,
use_backbone_lora=0,
use_llm_lora=0,
select_layer=-1,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
**kwargs):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {}
if llm_config is None:
llm_config = {}
self.vision_config = PretrainedConfig(**vision_config)
self.text_config = PretrainedConfig(**llm_config)
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
......@@ -226,7 +226,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
input_ids[input_ids < 0] = self.pad_id
if inference_mode:
# 去掉结尾的eos token
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
......
......@@ -124,12 +124,14 @@ def find_tokenizer_file(files: List[str]):
matched_files = [file for file in files if file_pattern.match(file)]
if len(matched_files) > 1:
raise OSError(f"Found {len(matched_files)} files matching the "
f"pattern: {file_pattern}. Make sure only one Mistral "
raise OSError(
f"Found {len(matched_files)} files matching the "
f"pattern: `{file_pattern.pattern}`. Make sure only one Mistral "
f"tokenizer is present in {files}.")
elif len(matched_files) == 0:
raise OSError(f"Found {len(matched_files)} files matching the "
f"pattern: {file_pattern}. Make sure that a Mistral "
raise OSError(
f"Found {len(matched_files)} files matching the "
f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral "
f"tokenizer is present in {files}.")
return matched_files[0]
......
# SPDX-License-Identifier: Apache-2.0
from functools import cache
from os import PathLike
from pathlib import Path
from typing import List, Optional, Union
from vllm.envs import VLLM_MODEL_REDIRECT_PATH
from vllm.logger import init_logger
logger = init_logger(__name__)
def is_s3(model_or_path: str) -> bool:
return model_or_path.lower().startswith('s3://')
......@@ -17,9 +23,14 @@ def check_gguf_file(model: Union[str, PathLike]) -> bool:
elif model.suffix == ".gguf":
return True
with open(model, "rb") as f:
try:
with model.open("rb") as f:
header = f.read(4)
return header == b"GGUF"
except Exception as e:
logger.debug("Error reading file %s: %s", model, e)
return False
def modelscope_list_repo_files(
......@@ -38,3 +49,35 @@ def modelscope_list_repo_files(
if file['Type'] == 'blob'
]
return files
@cache
def maybe_model_redirect(model: str) -> str:
"""
Use model_redirect to redirect the model name to a local folder.
:param model: hf model name
:return: maybe redirect to a local folder
"""
model_redirect_path = VLLM_MODEL_REDIRECT_PATH
if not model_redirect_path:
return model
if not Path(model_redirect_path).exists():
return model
with open(model_redirect_path) as f:
for line in f.readlines():
try:
model_name, redirect_name = line.split("\t")
if model == model_name:
redirect_name = redirect_name.strip()
logger.info("model redirect: [ %s ] -> [ %s ]", model,
redirect_name)
return redirect_name
except Exception:
pass
return model
......@@ -10,6 +10,7 @@ import datetime
import enum
import gc
import getpass
import hashlib
import importlib
import importlib.metadata
import importlib.util
......@@ -17,6 +18,7 @@ import inspect
import ipaddress
import multiprocessing
import os
import pickle
import re
import signal
import socket
......@@ -31,15 +33,17 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import OrderedDict, UserDict, defaultdict
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, Mapping)
Iterable, Iterator, KeysView, Mapping)
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Type, TypeVar, Union)
Optional, Type, TypeVar, Union, cast, overload)
from uuid import uuid4
import cachetools
import cloudpickle
import numpy as np
import numpy.typing as npt
......@@ -57,7 +61,7 @@ import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
logger = init_logger(__name__)
import json
......@@ -172,6 +176,7 @@ U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
_T = TypeVar("_T")
class _Sentinel:
......@@ -205,6 +210,19 @@ class Counter:
self.counter = 0
class _MappingOrderCacheView(UserDict[_K, _V]):
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
super().__init__(data)
self.ordered_keys = ordered_keys
def __iter__(self) -> Iterator[_K]:
return iter(self.ordered_keys)
def keys(self) -> KeysView[_K]:
return KeysView(self.ordered_keys)
class CacheInfo(NamedTuple):
hits: int
total: int
......@@ -217,45 +235,62 @@ class CacheInfo(NamedTuple):
return self.hits / self.total
class LRUCache(Generic[_K, _V]):
"""Note: This class is not thread safe!"""
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
def __init__(self, capacity: int) -> None:
self.cache = OrderedDict[_K, _V]()
def __init__(self,
capacity: float,
getsizeof: Optional[Callable[[_V], float]] = None):
super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]()
self.capacity = capacity
self._hits = 0
self._total = 0
def __contains__(self, key: _K) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __delitem__(self, key: _K) -> None:
run_on_remove = key in self
value = self.__getitem__(key)
super().__delitem__(key)
if key in self.pinned_items:
# Todo: add warning to inform that del pinned item
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value)
@property
def cache(self) -> Mapping[_K, _V]:
"""Return the internal cache dictionary in order (read-only)."""
return _MappingOrderCacheView(
self._Cache__data, # type: ignore
self.order)
def __delitem__(self, key: _K) -> None:
self.pop(key)
@property
def order(self) -> Mapping[_K, None]:
"""Return the internal order dictionary (read-only)."""
return MappingProxyType(self._LRUCache__order) # type: ignore
def stat(self) -> CacheInfo:
return CacheInfo(hits=self._hits, total=self._total)
def touch(self, key: _K) -> None:
self.cache.move_to_end(key)
self._LRUCache__update(key) # type: ignore
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
value: Optional[_V]
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
@overload
def get(self, key: _K, /) -> Optional[_V]:
...
@overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
...
def get(self,
key: _K,
/,
default: Optional[Union[_V,
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key in self:
value = self.__getitem__(key)
self._hits += 1
else:
......@@ -264,60 +299,76 @@ class LRUCache(Generic[_K, _V]):
self._total += 1
return value
@overload
def pop(self, key: _K) -> _V:
...
@overload
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
...
def pop(self,
key: _K,
default: Optional[Union[_V,
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key not in self:
return default
value = self[key]
del self[key]
return value
def put(self, key: _K, value: _V) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
self.__setitem__(key, value)
def pin(self, key: _K) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
"""
if key not in self.cache:
if key not in self:
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)
def _unpin(self, key: _K) -> None:
"""
Unpins a key in the cache allowing it to be
evicted in the LRU order.
"""
self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache:
if len(self) == 0:
return
self.popitem(remove_pinned=remove_pinned)
def _remove_old_if_needed(self) -> None:
while self.currsize > self.capacity:
self.remove_oldest()
def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)
def popitem(self, remove_pinned: bool = False):
"""Remove and return the `(key, value)` pair least recently used."""
if not remove_pinned:
# pop the oldest item in the cache that is not pinned
lru_key = next(
(key for key in self.cache if key not in self.pinned_items),
(key for key in self.order if key not in self.pinned_items),
ALL_PINNED_SENTINEL)
if lru_key is ALL_PINNED_SENTINEL:
raise RuntimeError("All items are pinned, "
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
self.pop(lru_key) # type: ignore
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
lru_key = next(iter(self.order))
value = self.pop(cast(_K, lru_key))
return (lru_key, value)
class PyObjectCache:
......@@ -528,7 +579,7 @@ def get_open_port() -> int:
dp_port = envs.VLLM_DP_MASTER_PORT
while True:
port = _get_open_port()
if port >= dp_port and port < dp_port + 10:
if dp_port <= port < dp_port + 10:
continue
return port
return _get_open_port()
......@@ -745,6 +796,14 @@ def is_pin_memory_available() -> bool:
return current_platform.is_pin_memory_available()
@cache
def is_uva_available() -> bool:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return is_pin_memory_available()
class DeviceMemoryProfiler:
def __init__(self, device: Optional[torch.types.Device] = None):
......@@ -1162,7 +1221,7 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'.")
class SortedHelpFormatter(argparse.HelpFormatter):
class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def add_arguments(self, actions):
......@@ -1183,6 +1242,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
model_in_cli_args = any(arg == '--model' for arg in args)
if model_in_cli_args:
raise ValueError(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option.")
if '--config' in args:
args = self._pull_args_from_config(args)
......@@ -1266,19 +1335,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
config_args = self._load_config_file(file_path)
# 0th index is for {serve,chat,complete}
# followed by model_tag (only for serve)
# optionally followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if args[0] == "serve":
if index == 1:
model_in_cli = len(args) > 1 and not args[1].startswith('-')
model_in_config = any(arg == '--model' for arg in config_args)
if not model_in_cli and not model_in_config:
raise ValueError(
"No model_tag specified! Please check your command-line"
" arguments.")
"No model specified! Please specify model either "
"as a positional argument or in a config file.")
if model_in_cli:
# Model specified as positional arg, keep CLI version
args = [args[0]] + [
args[1]
] + config_args + args[2:index] + args[index + 2:]
else:
# No model in CLI, use config if available
args = [args[0]
] + config_args + args[1:index] + args[index + 2:]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2:]
......@@ -1296,9 +1375,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split('.')[-1]
if extension not in ('yaml', 'yml'):
raise ValueError(
......@@ -1691,18 +1768,21 @@ class ClassRegistry(UserDict[Type[T], _V]):
return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
def weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
"""
if isinstance(tensor, torch.Tensor):
return torch.ops._C.weak_ref_tensor(tensor)
else:
return tensor
def weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]:
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
......@@ -1716,6 +1796,14 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
def is_in_doc_build() -> bool:
try:
from sphinx.ext.autodoc.mock import _MockModule
......@@ -2247,11 +2335,11 @@ def make_zmq_socket(
if socket_type == zmq.constants.PULL:
socket.setsockopt(zmq.constants.RCVHWM, 0)
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
socket.connect(path)
socket.bind(path)
elif socket_type == zmq.constants.PUSH:
socket.setsockopt(zmq.constants.SNDHWM, 0)
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
socket.bind(path)
socket.connect(path)
else:
raise ValueError(f"Unknown Socket Type: {socket_type}")
......@@ -2259,7 +2347,11 @@ def make_zmq_socket(
@contextlib.contextmanager
def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
def zmq_socket_ctx(
path: str,
socket_type: Any,
linger: int = 0,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined]
......@@ -2270,7 +2362,7 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
logger.debug("Got Keyboard Interrupt.")
finally:
ctx.destroy(linger=0)
ctx.destroy(linger=linger)
def is_in_ray_actor():
......@@ -2567,3 +2659,33 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
return wrapper
return decorator
# Only relevant for models using ALiBi (e.g, MPT)
def check_use_alibi(model_config: ModelConfig) -> bool:
return (getattr(model_config.hf_text_config, "alibi", False) # Falcon
or ("BloomForCausalLM" in getattr(model_config.hf_config,
"architectures", [])) # Bloom
or getattr(model_config.hf_text_config, "position_encoding_type",
"") == "alibi" # codellm_1b_alibi
or
(hasattr(model_config.hf_text_config, "attn_config") # MPT
and model_config.hf_text_config.attn_config.get("alibi", False)))
def sha256(input) -> int:
"""Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows
arbitrary Python objects to be used. Note that this function does
not use a hash seed—if you need one, prepend it explicitly to the input.
Args:
input: Any picklable Python object.
Returns:
An integer representing the SHA-256 hash of the serialized input.
"""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
......@@ -96,6 +96,183 @@ class FlashAttentionMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# for local attention
@dataclass
class LocalAttentionMetadata:
local_query_start_loc: torch.Tensor
local_seqused_k: torch.Tensor
local_block_table: torch.Tensor
local_max_query_len: int
local_max_seq_len: int
local_attn_metadata: Optional[LocalAttentionMetadata] = None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if are performing a chunked prefill a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block = np.minimum(
attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
q_seqlens).astype(np.int32)
tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
attn_chunk_size)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: max a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange = np.repeat(local_blocks, local_blocks) - arange - 1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local = \
np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks)
# set the first block since this may be a partial block
seqlens_q_local[arange == 0] = q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local[arange > 0] = np.minimum(
seqlens_q_local - attn_chunk_size * (arange - 1),
attn_chunk_size)[arange > 0]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\
.astype(np.int32)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local = np.full(cu_num_blocks[-1],
attn_chunk_size,
dtype=np.int32)
seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block
k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \
(rarange * attn_chunk_size + \
np.repeat(tokens_in_last_block, local_blocks))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // page_size
assert attn_chunk_size % page_size == 0, \
f"attn_chunk_size {attn_chunk_size} is not " \
f"divisible by page_size {page_size}"
pages_per_local_batch = attn_chunk_size // page_size
# Create a block_table for the local attention blocks
# For out example if we have a block-table like (assuming page_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices= np.broadcast_to(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch)) \
+ np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten()
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch)
block_table_local = block_table[batch_indices, block_indices]\
.view(virtual_batches, -1)
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
block_table_local
class FlashAttentionMetadataBuilder:
......@@ -109,18 +286,40 @@ class FlashAttentionMetadataBuilder:
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True)
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
block_table,
self.runner.block_size,
)
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(
virt_q_cu_seqlens_np).to(self.runner.device,
non_blocking=True),
local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True),
local_block_table=virt_block_table,
local_max_query_len=seqlens_q_local_np.max(),
local_max_seq_len=virt_k_seqlens_np.max(),
)
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
......@@ -149,6 +348,7 @@ class FlashAttentionMetadataBuilder:
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
)
return attn_metadata
......@@ -167,6 +367,7 @@ class FlashAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
......@@ -203,6 +404,7 @@ class FlashAttentionImpl(AttentionImpl):
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttentionImpl")
self.use_irope = use_irope
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
and not flash_attn_supports_fp8():
......@@ -265,8 +467,7 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale,
layer._v_scale,
)
descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
key.shape[1])
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
......@@ -278,22 +479,41 @@ class FlashAttentionImpl(AttentionImpl):
query = query.reshape((num_tokens, num_heads, head_size))
# Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade:
# Regular attention (common case).
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if not attn_metadata.use_cascade or use_local_attn:
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
seqused_k=attn_metadata.seq_lens,
max_seqlen_k=attn_metadata.max_seq_len,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table,
block_table=block_table,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
......@@ -302,6 +522,8 @@ class FlashAttentionImpl(AttentionImpl):
)
return output
assert not use_local_attn, (
"Cascade attention does not support local attention.")
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],
......
......@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128
class PallasAttentionBackend(AttentionBackend):
......@@ -41,7 +37,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads: int,
head_size: int,
) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads * head_size)
return (num_blocks, block_size, num_kv_heads * 2, head_size)
@staticmethod
def swap_blocks(
......@@ -92,6 +88,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.logits_soft_cap = logits_soft_cap
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
......@@ -99,15 +97,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if logits_soft_cap is not None:
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
......@@ -118,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
if tpu_version == 4:
self.vmem_limit_bytes = 16 * 1024 * 1024
else:
self.vmem_limit_bytes = 64 * 1024 * 1024
def forward(
self,
......@@ -132,7 +118,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: tuple[torch.Tensor, torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
......@@ -142,14 +128,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
[num_blocks, block_size, num_kv_heads * head_size])
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# For determine_available_memory case.
if kv_cache[0].numel() == 0:
if kv_cache.numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
......@@ -158,24 +143,28 @@ class PallasAttentionBackendImpl(AttentionImpl):
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key_cache, value_cache = kv_cache
if kv_cache[0].numel() > 0:
if kv_cache.numel() > 0:
slot_mapping = attn_metadata.slot_mapping
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
write_to_kv_cache(key, value, kv_cache, slot_mapping)
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
value_cache,
kv_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
vmem_limit_bytes=self.vmem_limit_bytes,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale)
sm_scale=self.scale,
sliding_window=self.sliding_window,
soft_cap=self.logits_soft_cap,
)
return output.reshape(num_tokens, hidden_size)
......@@ -183,8 +172,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
def write_to_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.
......@@ -192,14 +180,19 @@ def write_to_kv_cache(
Args:
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
"""
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
_, _, num_combined_kv_heads, head_size = kv_cache.shape
num_kv_heads = num_combined_kv_heads // 2
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
head_size)
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
key_cache = key_cache.flatten(0, 1)
value_cache = value_cache.flatten(0, 1)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
kv_cache = kv_cache.flatten(0, 1)
kv_cache.index_copy_(0, slot_mapping, kv)
......@@ -70,6 +70,7 @@ class TritonAttentionImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
......@@ -86,6 +87,7 @@ class TritonAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
self.use_irope = use_irope
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
......@@ -156,19 +158,37 @@ class TritonAttentionImpl(AttentionImpl):
layer._v_scale,
)
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
if use_local_attn:
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
sequesd_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
sequesd_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=attn_metadata.block_table,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
max_query_len=attn_metadata.max_query_len,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=sequesd_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from typing import Callable, Optional
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
......@@ -27,6 +27,7 @@ class BlockPool:
"""
def __init__(self, num_gpu_blocks: int, enable_caching: bool):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.
......@@ -50,6 +51,11 @@ class BlockPool:
self.cached_block_hash_to_block: dict[BlockHashType, dict[
int, KVCacheBlock]] = defaultdict(dict)
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
def get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
......@@ -75,6 +81,7 @@ class BlockPool:
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
......@@ -93,6 +100,7 @@ class BlockPool:
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
return
......@@ -138,7 +146,7 @@ class BlockPool:
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
......@@ -212,7 +220,7 @@ class BlockPool:
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
if block.ref_cnt == 0 and block != self.null_block:
self.free_block_queue.remove(block)
block.incr_ref()
......@@ -226,7 +234,8 @@ class BlockPool:
"""
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
# null_block should not be added to the free list.
if block.ref_cnt == 0 and block != self.null_block:
self.free_block_queue.append(block)
def reset_prefix_cache(self) -> bool:
......@@ -239,10 +248,10 @@ class BlockPool:
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
if num_used_blocks > 0:
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks)
"blocks (%d) are not freed yet", num_used_blocks - 1)
return False
# Remove all hashes so that no new blocks will hit.
......
......@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
......@@ -67,6 +67,7 @@ class EncoderCacheManager:
def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
......@@ -74,6 +75,7 @@ def compute_encoder_budget(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
......@@ -89,7 +91,11 @@ def compute_encoder_budget(
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(model_config, scheduler_config)
) = _compute_encoder_budget_multimodal(
model_config,
scheduler_config,
mm_registry,
)
return encoder_compute_budget, encoder_cache_size
......@@ -97,6 +103,7 @@ def compute_encoder_budget(
def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
......@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
......@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
in the input sequence.
"""
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)
max_tokens_by_modality_dict = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens_by_modality_dict:
logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment