Commit 31f6b24f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori

parents 89d1dd57 25f560a6
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Optional
import torch
class Torch25CustomGraphPass(ABC): # noqa (redefinition)
"""
This class replaces CustomGraphPass from torch==2.6 when using torch<2.6.
It conforms to the 2.6 interface but also supports pickling, as that's what
the inductor code cache uses to determine the cache key before 2.6.
(in 2.6 and above, uuid() is used.)
Subclasses can just "pretend" that uuid is used.
"""
@abstractmethod
def __call__(self, graph: torch.fx.graph.Graph) -> None:
"""
Implementation of the custom pass.
"""
@abstractmethod
def uuid(self) -> Optional[Any]:
"""
Return an ID to uniquely identify your custom pass implementation.
Return None to skip inductor code caching entirely.
"""
def __getstate__(self):
"""
Pickling is used instead of uuid() in torch<2.6. Just return uuid()
to enable subclasses to only have to implement uuid.
"""
return self.uuid()
def __setstate__(self, state):
raise ValueError("Cannot unpickle CustomGraphPass because pickling"
" is used for cache key uuid. Use torch>=2.6 with"
" native uuid support for custom passes.")
...@@ -4,6 +4,7 @@ import ast ...@@ -4,6 +4,7 @@ import ast
import copy import copy
import enum import enum
import hashlib import hashlib
import importlib.metadata
import json import json
import sys import sys
import warnings import warnings
...@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal, ...@@ -17,6 +18,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, Union) Optional, Protocol, Union)
import torch import torch
from packaging.version import Version
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -52,8 +54,6 @@ if TYPE_CHECKING: ...@@ -52,8 +54,6 @@ if TYPE_CHECKING:
else: else:
QuantizationConfig = None QuantizationConfig = None
from packaging.version import Version
logger = init_logger(__name__) logger = init_logger(__name__)
# This value is chosen to have a balance between ITL and TTFT. Note it is # This value is chosen to have a balance between ITL and TTFT. Note it is
...@@ -1157,10 +1157,6 @@ class CacheConfig: ...@@ -1157,10 +1157,6 @@ class CacheConfig:
if self.cache_dtype == "auto": if self.cache_dtype == "auto":
pass pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
if envs.VLLM_USE_V1:
raise NotImplementedError(
"V1 does not yet support fp8 KV cache. "
"Set VLLM_USE_V1=0 to enable fp8 kv cache.")
logger.info( logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU " "Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. " "memory footprint and boosts the performance. "
...@@ -1281,6 +1277,7 @@ class LoadFormat(str, enum.Enum): ...@@ -1281,6 +1277,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes" BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral" MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer" RUNAI_STREAMER = "runai_streamer"
FASTSAFETENSORS = "fastsafetensors"
@dataclass @dataclass
...@@ -2376,12 +2373,6 @@ class LoRAConfig: ...@@ -2376,12 +2373,6 @@ class LoRAConfig:
self.lora_dtype = model_config.dtype self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str): elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype) self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization and model_config.quantization not in [
"awq", "gptq"
]:
# TODO support marlin
logger.warning("%s quantization is not tested with LoRA yet.",
model_config.quantization)
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/features/compatibility_matrix.md # Reminder: Please update docs/source/features/compatibility_matrix.md
...@@ -2809,12 +2800,17 @@ class DecodingConfig: ...@@ -2809,12 +2800,17 @@ class DecodingConfig:
return hash_str return hash_str
def __post_init__(self): def __post_init__(self):
valid_guided_backends = [ v0_valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance' 'outlines', 'lm-format-enforcer', 'xgrammar'
] ]
v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
backend = GuidedDecodingParams( backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name backend=self.guided_decoding_backend).backend_name
if envs.VLLM_USE_V1:
valid_guided_backends = v1_valid_guided_backends
else:
valid_guided_backends = v0_valid_guided_backends
if backend not in valid_guided_backends: if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}'," raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
f" must be one of {valid_guided_backends}") f" must be one of {valid_guided_backends}")
...@@ -3092,8 +3088,7 @@ class CompilationConfig(BaseModel): ...@@ -3092,8 +3088,7 @@ class CompilationConfig(BaseModel):
compilation. compilation.
""" """
dict_ = self.model_dump(include={"enable_fusion", "enable_noop"}) dict_ = self.model_dump(include={"enable_fusion", "enable_noop"})
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return InductorPass.hash_dict(dict_)
return hashlib.sha256(encoded).digest()
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None:
if not self.enable_noop and self.enable_fusion: if not self.enable_noop and self.enable_fusion:
...@@ -3182,7 +3177,7 @@ class CompilationConfig(BaseModel): ...@@ -3182,7 +3177,7 @@ class CompilationConfig(BaseModel):
# and it is not yet a priority. RFC here: # and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703 # https://github.com/vllm-project/vllm/issues/14703
if Version(torch.__version__) >= Version("2.6"): if Version(importlib.metadata.version('torch')) >= Version("2.6"):
KEY = 'enable_auto_functionalized_v2' KEY = 'enable_auto_functionalized_v2'
if KEY not in self.inductor_compile_config: if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False self.inductor_compile_config[KEY] = False
......
...@@ -233,6 +233,7 @@ class MessageQueue: ...@@ -233,6 +233,7 @@ class MessageQueue:
if is_valid_ipv6_address(connect_ip): if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1) self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True remote_addr_ipv6 = True
connect_ip = f"[{connect_ip}]"
socket_addr = f"tcp://*:{remote_subscribe_port}" socket_addr = f"tcp://*:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr) self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
...@@ -356,8 +357,11 @@ class MessageQueue: ...@@ -356,8 +357,11 @@ class MessageQueue:
# if we wait for a long time, log a message # if we wait for a long time, log a message
if (time.monotonic() - start_time if (time.monotonic() - start_time
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ", logger.debug(
VLLM_RINGBUFFER_WARNING_INTERVAL) ("No available shared memory broadcast block found"
" in %s second."),
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1 n_warning += 1
# if we time out, raise an exception # if we time out, raise an exception
...@@ -414,8 +418,11 @@ class MessageQueue: ...@@ -414,8 +418,11 @@ class MessageQueue:
# if we wait for a long time, log a message # if we wait for a long time, log a message
if (time.monotonic() - start_time if (time.monotonic() - start_time
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ", logger.debug(
VLLM_RINGBUFFER_WARNING_INTERVAL) ("No available shared memory broadcast block found"
"in %s second."),
VLLM_RINGBUFFER_WARNING_INTERVAL,
)
n_warning += 1 n_warning += 1
# if we time out, raise an exception # if we time out, raise an exception
......
...@@ -897,29 +897,22 @@ def initialize_model_parallel( ...@@ -897,29 +897,22 @@ def initialize_model_parallel(
get_world_group().device_group) get_world_group().device_group)
data_parallel_size = 1 data_parallel_size = 1
has_external_dp = False
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
config = get_current_vllm_config() config = get_current_vllm_config()
if config is not None: if config is not None:
if config.parallel_config.world_size != world_size:
# detect external data parallelism.
# dp in vllm means all dp instances need to run together.
# if the world size does not match, it means this dp is external,
# and the dp instances can run independently, e.g. in rlhf workflow
# from https://github.com/volcengine/verl .
# in that case, we treat the rest dimensions as if they are
# data parallel, and create a dummy dp group that is not used.
data_parallel_size = world_size // (pipeline_model_parallel_size *
tensor_model_parallel_size)
has_external_dp = True
else:
data_parallel_size = config.parallel_config.data_parallel_size data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: DP x PP x TP # the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
# DP is the data parallel group that is part of the model,
# all the ranks in the same DP group should generate simultaneously,
# i.e. the `generate` call in the same DP group should be called together,
# otherwise it will cause deadlock.
# to get group_ranks for each dimension, transpose that dimension to the # to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension # last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape( all_ranks = torch.arange(world_size).reshape(
data_parallel_size, pipeline_model_parallel_size, -1, data_parallel_size, pipeline_model_parallel_size,
tensor_model_parallel_size) # noqa tensor_model_parallel_size) # noqa
# Build the tensor model-parallel groups. # Build the tensor model-parallel groups.
...@@ -939,7 +932,7 @@ def initialize_model_parallel( ...@@ -939,7 +932,7 @@ def initialize_model_parallel(
global _PP global _PP
assert _PP is None, ( assert _PP is None, (
"pipeline model parallel group is already initialized") "pipeline model parallel group is already initialized")
group_ranks = all_ranks.transpose(1, 2).reshape( group_ranks = all_ranks.transpose(2, 3).reshape(
-1, pipeline_model_parallel_size).unbind(0) -1, pipeline_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(group_ranks, _PP = init_model_parallel_group(group_ranks,
...@@ -949,16 +942,10 @@ def initialize_model_parallel( ...@@ -949,16 +942,10 @@ def initialize_model_parallel(
global _DP global _DP
assert _DP is None, ("data parallel group is already initialized") assert _DP is None, ("data parallel group is already initialized")
group_ranks = all_ranks.transpose(0, group_ranks = all_ranks.transpose(1,
2).reshape(-1, 3).reshape(-1,
data_parallel_size).unbind(0) data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
if has_external_dp:
# create a dummy dp group that is not used actually,
# since this dp is external.
# a dummy dp group means every rank is a group itself.
# this way, no communication is needed, no memory is wasted.
group_ranks = [[x] for x in range(world_size)]
_DP = init_model_parallel_group(group_ranks, _DP = init_model_parallel_group(group_ranks,
get_world_group().local_rank, get_world_group().local_rank,
backend, backend,
......
...@@ -391,16 +391,13 @@ class EngineArgs: ...@@ -391,16 +391,13 @@ class EngineArgs:
default='xgrammar', default='xgrammar',
help='Which engine will be used for guided decoding' help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support ' ' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines, ' 'https://github.com/mlc-ai/xgrammar and '
'https://github.com/mlc-ai/xgrammar, and ' 'https://github.com/guidance-ai/llguidance.'
'https://github.com/noamgat/lm-format-enforcer.' 'Valid backend values are "xgrammar", "guidance", and "auto". '
' Can be overridden per request via guided_decoding_backend' 'With "auto", we will make opinionated choices based on request'
' parameter.\n' 'contents and what the backend libraries currently support, so '
'Backend-specific options can be supplied in a comma-separated ' 'the behavior is subject to change in each release. '
'list following a colon after the backend name. Valid backends and ' 'The default is xgrammar.')
'all available options are: [xgrammar:no-fallback, '
'xgrammar:disable-any-whitespace, '
'outlines:no-fallback, lm-format-enforcer:no-fallback]')
parser.add_argument( parser.add_argument(
'--logits-processor-pattern', '--logits-processor-pattern',
type=nullable_str, type=nullable_str,
...@@ -1539,9 +1536,9 @@ class EngineArgs: ...@@ -1539,9 +1536,9 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# Only support Xgrammar for guided decoding so far. # Xgrammar and Guidance are supported.
SUPPORTED_GUIDED_DECODING = [ SUPPORTED_GUIDED_DECODING = [
"xgrammar", "xgrammar:disable-any-whitespace" "xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
] ]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING: if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend", _raise_or_fallback(feature_name="--guided-decoding-backend",
...@@ -1562,6 +1559,17 @@ class EngineArgs: ...@@ -1562,6 +1559,17 @@ class EngineArgs:
# No Fp8 KV cache so far. # No Fp8 KV cache so far.
if self.kv_cache_dtype != "auto": if self.kv_cache_dtype != "auto":
fp8_attention = self.kv_cache_dtype.startswith("fp8")
will_use_fa = (
current_platform.is_cuda()
and not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if fp8_attention and will_use_fa:
from vllm.vllm_flash_attn.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
if not supported:
_raise_or_fallback(feature_name="--kv-cache-dtype", _raise_or_fallback(feature_name="--kv-cache-dtype",
recommend_to_remove=False) recommend_to_remove=False)
return False return False
......
...@@ -545,7 +545,7 @@ async def build_guided_decoding_logits_processor_async( ...@@ -545,7 +545,7 @@ async def build_guided_decoding_logits_processor_async(
sampling_params = copy.copy(sampling_params) sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding guided_decoding = sampling_params.guided_decoding
logger.info( logger.debug(
"Building guided decoding logits processor. " "Building guided decoding logits processor. "
"guided_decoding: %s%s", guided_decoding, "guided_decoding: %s%s", guided_decoding,
f", reasoning_backend: {reasoning_backend}" f", reasoning_backend: {reasoning_backend}"
......
...@@ -1249,7 +1249,7 @@ class LLMEngine: ...@@ -1249,7 +1249,7 @@ class LLMEngine:
return None return None
def _advance_to_next_step( def _advance_to_next_step(
self, output: List[SamplerOutput], self, output: SamplerOutput,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None:
"""Given model output from a single run, append the tokens to the """Given model output from a single run, append the tokens to the
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import codecs
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict, deque from collections import defaultdict, deque
...@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import ( ...@@ -30,7 +29,8 @@ from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio) InputAudio)
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
from typing_extensions import Required, TypeAlias, TypedDict from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -306,24 +306,63 @@ def _detect_content_format( ...@@ -306,24 +306,63 @@ def _detect_content_format(
return "openai" return "openai"
def _resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
processor.chat_template is not None:
return processor.chat_template
except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path, exc_info=True)
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
return None
def _resolve_chat_template_content_format( def _resolve_chat_template_content_format(
chat_template: Optional[str], chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat: ) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
tokenizer_chat_template = tokenizer.chat_template hf_chat_template = _resolve_hf_chat_template(
else: tokenizer,
tokenizer_chat_template = None chat_template=chat_template,
trust_remote_code=trust_remote_code,
jinja_text: Optional[str] tools=tools,
if isinstance(tokenizer_chat_template, str) and chat_template is None: )
jinja_text = tokenizer_chat_template
elif (isinstance(tokenizer_chat_template, dict)
and chat_template in tokenizer_chat_template):
jinja_text = tokenizer_chat_template[chat_template]
else: else:
jinja_text = load_chat_template(chat_template, is_literal=True) hf_chat_template = None
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))
detected_format = ("string" if jinja_text is None else detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string")) _detect_content_format(jinja_text, default="string"))
...@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format( ...@@ -332,17 +371,11 @@ def _resolve_chat_template_content_format(
@lru_cache @lru_cache
def resolve_chat_template_content_format( def _log_chat_template_content_format(
chat_template: Optional[str], chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption, given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer, detected_format: ChatTemplateContentFormatOption,
) -> _ChatTemplateContentFormat: ):
detected_format = _resolve_chat_template_content_format(
chat_template,
given_format,
tokenizer,
)
logger.info( logger.info(
"Detected the chat template content format to be '%s'. " "Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.", "You can set `--chat-template-content-format` to override this.",
...@@ -360,6 +393,29 @@ def resolve_chat_template_content_format( ...@@ -360,6 +393,29 @@ def resolve_chat_template_content_format(
detected_format, detected_format,
) )
def resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format return detected_format
...@@ -500,11 +556,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): ...@@ -500,11 +556,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
raise ValueError(\ raise ValueError(\
"Only one message can have {'type': 'image_embeds'}") "Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
elif "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
elif "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
elif "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
...@@ -533,11 +589,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): ...@@ -533,11 +589,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
raise ValueError( raise ValueError(
"Only one message can have {'type': 'image_embeds'}") "Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
elif "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
elif "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
elif "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
...@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): ...@@ -711,7 +767,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
f"{type(chat_template)} is not a valid chat template type") f"{type(chat_template)} is not a valid chat template type")
def load_chat_template( def _load_chat_template(
chat_template: Optional[Union[Path, str]], chat_template: Optional[Union[Path, str]],
*, *,
is_literal: bool = False, is_literal: bool = False,
...@@ -724,7 +780,7 @@ def load_chat_template( ...@@ -724,7 +780,7 @@ def load_chat_template(
raise TypeError("chat_template is expected to be read directly " raise TypeError("chat_template is expected to be read directly "
"from its value") "from its value")
return codecs.decode(chat_template, "unicode_escape") return chat_template
try: try:
with open(chat_template) as f: with open(chat_template) as f:
...@@ -742,7 +798,18 @@ def load_chat_template( ...@@ -742,7 +798,18 @@ def load_chat_template(
# If opening a file fails, set chat template to be args to # If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly # ensure we decode so our escape are interpreted correctly
return load_chat_template(chat_template, is_literal=True) return _load_chat_template(chat_template, is_literal=True)
_cached_load_chat_template = lru_cache(_load_chat_template)
def load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
return _cached_load_chat_template(chat_template, is_literal=is_literal)
# TODO: Let user specify how to insert multimodal tokens into prompt # TODO: Let user specify how to insert multimodal tokens into prompt
...@@ -1067,23 +1134,20 @@ def apply_hf_chat_template( ...@@ -1067,23 +1134,20 @@ def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
chat_template: Optional[str], chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*, *,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default tokenize: bool = False, # Different from HF's default
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
if chat_template is None: hf_chat_template = _resolve_hf_chat_template(
chat_template = tokenizer.chat_template tokenizer,
chat_template=chat_template,
# FIXME: Temporary workaround for tools=tools,
# https://huggingface.co/mistral-community/pixtral-12b/discussions/31 trust_remote_code=trust_remote_code,
if chat_template is None: )
try:
processor = cached_get_processor(tokenizer.name_or_path)
chat_template = processor.chat_template
except Exception:
pass
if chat_template is None: if hf_chat_template is None:
raise ValueError( raise ValueError(
"As of transformers v4.44, default chat template is no longer " "As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
...@@ -1091,7 +1155,8 @@ def apply_hf_chat_template( ...@@ -1091,7 +1155,8 @@ def apply_hf_chat_template(
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template, tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **kwargs,
) )
...@@ -1100,7 +1165,8 @@ def apply_hf_chat_template( ...@@ -1100,7 +1165,8 @@ def apply_hf_chat_template(
def apply_mistral_chat_template( def apply_mistral_chat_template(
tokenizer: MistralTokenizer, tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
chat_template: Optional[str] = None, chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
**kwargs: Any, **kwargs: Any,
) -> list[int]: ) -> list[int]:
if chat_template is not None: if chat_template is not None:
...@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template( ...@@ -1117,5 +1183,6 @@ def apply_mistral_chat_template(
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
messages=messages, messages=messages,
tools=tools,
**kwargs, **kwargs,
) )
...@@ -690,8 +690,10 @@ class LLM: ...@@ -690,8 +690,10 @@ class LLM:
model_config = self.llm_engine.get_model_config() model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tools,
chat_template_content_format, chat_template_content_format,
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
) )
prompts: list[Union[TokensPrompt, TextPrompt]] = [] prompts: list[Union[TokensPrompt, TextPrompt]] = []
...@@ -713,18 +715,19 @@ class LLM: ...@@ -713,18 +715,19 @@ class LLM:
tokenizer, tokenizer,
messages=msgs, messages=msgs,
chat_template=chat_template, chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
tools=tools,
) )
else: else:
prompt_data = apply_hf_chat_template( prompt_data = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
chat_template=chat_template, chat_template=chat_template,
tools=tools,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message, continue_final_message=continue_final_message,
tools=tools,
) )
prompt: Union[TokensPrompt, TextPrompt] prompt: Union[TokensPrompt, TextPrompt]
......
...@@ -379,14 +379,18 @@ class OpenAIServing: ...@@ -379,14 +379,18 @@ class OpenAIServing:
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt], ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
list[TokensPrompt]]: list[TokensPrompt]]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
chat_template, chat_template,
tool_dicts,
chat_template_content_format, chat_template_content_format,
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
) )
conversation, mm_data_future = parse_chat_messages_futures( conversation, mm_data_future = parse_chat_messages_futures(
messages, messages,
self.model_config, model_config,
tokenizer, tokenizer,
content_format=resolved_content_format, content_format=resolved_content_format,
) )
...@@ -410,6 +414,7 @@ class OpenAIServing: ...@@ -410,6 +414,7 @@ class OpenAIServing:
else: else:
request_prompt = apply_hf_chat_template( request_prompt = apply_hf_chat_template(
tokenizer, tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation, conversation=conversation,
**_chat_template_kwargs, **_chat_template_kwargs,
) )
......
...@@ -75,6 +75,7 @@ if TYPE_CHECKING: ...@@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False VLLM_DISABLE_COMPILE_CACHE: bool = False
...@@ -294,7 +295,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -294,7 +295,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# this is used for configuring the default logging level # this is used for configuring the default logging level
"VLLM_LOGGING_LEVEL": "VLLM_LOGGING_LEVEL":
lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"), lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(),
# if set, VLLM_LOGGING_PREFIX will be prepended to all log messages # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
"VLLM_LOGGING_PREFIX": "VLLM_LOGGING_PREFIX":
...@@ -340,7 +341,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -340,7 +341,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# (CPU backend only) CPU key-value cache space. # (CPU backend only) CPU key-value cache space.
# default is 4GB # default is 4 GiB
"VLLM_CPU_KVCACHE_SPACE": "VLLM_CPU_KVCACHE_SPACE":
lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
...@@ -412,9 +413,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -412,9 +413,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Cache size (in GiB) for multimodal input cache # Cache size (in GiB) for multimodal input cache
# Default is 8GiB # Default is 4 GiB
"VLLM_MM_INPUT_CACHE_GIB": "VLLM_MM_INPUT_CACHE_GIB":
lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "8")), lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
# Path to the XLA persistent cache directory. # Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
...@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_FP8_PADDING": "VLLM_ROCM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
# Pad the weights for the moe kernel
"VLLM_ROCM_MOE_PADDING":
lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache # Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT": "Q_SCALE_CONSTANT":
lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")),
......
...@@ -289,16 +289,14 @@ def initialize_ray_cluster( ...@@ -289,16 +289,14 @@ def initialize_ray_cluster(
elif current_platform.is_rocm() or current_platform.is_xpu(): elif current_platform.is_rocm() or current_platform.is_xpu():
# Try to connect existing ray instance and create a new one if not found # Try to connect existing ray instance and create a new one if not found
try: try:
ray.init("auto", ignore_reinit_error=True) ray.init("auto")
except ConnectionError: except ConnectionError:
logger.warning( logger.warning(
"No existing RAY instance detected. " "No existing RAY instance detected. "
"A new instance will be launched with current node resources.") "A new instance will be launched with current node resources.")
ray.init(address=ray_address, ray.init(address=ray_address, num_gpus=parallel_config.world_size)
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
else: else:
ray.init(address=ray_address, ignore_reinit_error=True) ray.init(address=ray_address)
device_str = current_platform.ray_device_key device_str = current_platform.ray_device_key
if not device_str: if not device_str:
......
...@@ -78,10 +78,6 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -78,10 +78,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
...], scale: float, **kwargs): ...], scale: float, **kwargs):
""" """
Performs GEMM for multiple slices of lora_a. Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics: Semantics:
for i in range(len(lora_a_stacked)): for i in range(len(lora_a_stacked)):
...@@ -226,7 +222,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -226,7 +222,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
if buffer is None: if buffer is None:
r = lora_b_stacked[0].size(-1) r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to: # We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387 # https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros( # type: ignore buffer = torch.zeros( # type: ignore
(len(output_slices), x.size(0), r), (len(output_slices), x.size(0), r),
...@@ -268,16 +264,16 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -268,16 +264,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
y (torch.Tensor): Output tensor. y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor. x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights. lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights. lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor. scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None. buffer (Optional[torch.Tensor]): Default to None.
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1]) x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1) r = lora_b_stacked.size(-1)
if buffer is None: if buffer is None:
# We set the buffer to be float32 by default ,refer to: # We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387 # https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
......
...@@ -815,7 +815,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -815,7 +815,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
B.shape[1] if not use_nn_moe else B.shape[2], B.shape[1] if not use_nn_moe else B.shape[2],
A.shape[1], A.shape[2],
EM, EM,
topk_ids.numel(), topk_ids.numel(),
A.stride(0), A.stride(0),
...@@ -1355,8 +1355,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1355,8 +1355,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [ assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
] ]
......
...@@ -6,6 +6,7 @@ from enum import Enum ...@@ -6,6 +6,7 @@ from enum import Enum
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
import torch import torch
import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter from torch.nn.parameter import UninitializedParameter
from vllm import envs from vllm import envs
...@@ -111,9 +112,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -111,9 +112,27 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory
if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
and weight.stride(-1) == 1
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
return weight
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w13_weight.data),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
if current_platform.is_cpu(): if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86: if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
...@@ -233,6 +252,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -233,6 +252,34 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias, e_score_correction_bias,
) )
def forward_hpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert layer is not None
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for HPU.")
if e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for HPU.")
return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k)
def forward_tpu( def forward_tpu(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -432,6 +479,9 @@ class FusedMoE(torch.nn.Module): ...@@ -432,6 +479,9 @@ class FusedMoE(torch.nn.Module):
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.") "non-grouped topk.")
if current_platform.is_hpu():
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
......
...@@ -155,12 +155,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase): ...@@ -155,12 +155,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
def __init__(self, quant_config: BitsAndBytesConfig): def __init__(self, quant_config: BitsAndBytesConfig):
try: try:
import bitsandbytes import bitsandbytes
if bitsandbytes.__version__ < "0.45.0": if bitsandbytes.__version__ < "0.45.3":
raise ImportError("bitsandbytes version is wrong. Please " raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.0.") "install bitsandbytes>=0.45.3.")
except ImportError as err: except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.45.0 via " raise ImportError("Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.0` to use " "`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer.") from err "bitsandbytes quantizer.") from err
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -255,7 +255,7 @@ class Fp8LinearMethod(LinearMethodBase):
else: else:
layer.register_parameter("input_scale", None) layer.register_parameter("input_scale", None)
def add_padding_to_weight(self, weight: torch.Tensor) -> torch.Tensor: def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
# Pad the weight tensor. This is an optimization on ROCm platform, which # Pad the weight tensor. This is an optimization on ROCm platform, which
# can benefit from tensors located far enough from one another in memory # can benefit from tensors located far enough from one another in memory
if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm() if (envs.VLLM_ROCM_FP8_PADDING and current_platform.is_rocm()
...@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -279,7 +279,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight.data weight = layer.weight.data
weight_scale_inv = layer.weight_scale_inv.data weight_scale_inv = layer.weight_scale_inv.data
weight = self.add_padding_to_weight(weight) weight = self._maybe_pad_weight(weight)
# Torch.compile cannot use Parameter subclasses. # Torch.compile cannot use Parameter subclasses.
layer.weight = Parameter(weight, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False)
...@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -343,7 +343,7 @@ class Fp8LinearMethod(LinearMethodBase):
logical_widths=layer.logical_widths, logical_widths=layer.logical_widths,
) )
weight = self.add_padding_to_weight(weight) weight = self._maybe_pad_weight(weight)
# Update layer with new values. # Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
......
...@@ -1187,7 +1187,8 @@ def _build_sampler_output( ...@@ -1187,7 +1187,8 @@ def _build_sampler_output(
deferred_sample_results_args=deferred_sample_results_args) deferred_sample_results_args=deferred_sample_results_args)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: def _get_next_prompt_tokens(
seq_group: SequenceGroupToSample) -> tuple[int, ...]:
"""Get a list of next prompt tokens to compute logprob from a """Get a list of next prompt tokens to compute logprob from a
given sequence group. given sequence group.
......
...@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping, ...@@ -49,9 +49,10 @@ from vllm.model_executor.model_loader.utils import (ParamMapping,
set_default_torch_dtype) set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
get_gguf_extra_tensor_names, get_lock, gguf_quant_weights_iterator, filter_files_not_needed_for_inference, get_gguf_extra_tensor_names,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, get_lock, gguf_quant_weights_iterator, initialize_dummy_weights,
np_cache_weights_iterator, pt_weights_iterator,
runai_safetensors_weights_iterator, safetensors_weights_iterator) runai_safetensors_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -275,7 +276,8 @@ class DefaultModelLoader(BaseModelLoader):
# Some quantized models use .pt files for storing the weights. # Some quantized models use .pt files for storing the weights.
if load_format == LoadFormat.AUTO: if load_format == LoadFormat.AUTO:
allow_patterns = ["*.safetensors", "*.bin"] allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == LoadFormat.SAFETENSORS: elif (load_format == LoadFormat.SAFETENSORS
or load_format == LoadFormat.FASTSAFETENSORS):
use_safetensors = True use_safetensors = True
allow_patterns = ["*.safetensors"] allow_patterns = ["*.safetensors"]
elif load_format == LoadFormat.MISTRAL: elif load_format == LoadFormat.MISTRAL:
...@@ -357,6 +359,12 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -357,6 +359,12 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,
) )
elif use_safetensors: elif use_safetensors:
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
weights_iterator = safetensors_weights_iterator( weights_iterator = safetensors_weights_iterator(
hf_weights_files, hf_weights_files,
self.load_config.use_tqdm_on_load, self.load_config.use_tqdm_on_load,
...@@ -379,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -379,6 +387,16 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = _xla_weights_iterator(weights_iterator) weights_iterator = _xla_weights_iterator(weights_iterator)
elif current_platform.is_hpu():
import habana_frameworks.torch.core as htcore
def _hpu_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
htcore.mark_step()
weights_iterator = _hpu_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0: if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter() self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix. # Apply the prefix.
...@@ -862,12 +880,12 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -862,12 +880,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
try: try:
import bitsandbytes import bitsandbytes
if bitsandbytes.__version__ < "0.45.0": if bitsandbytes.__version__ < "0.45.3":
raise ImportError("bitsandbytes version is wrong. Please " raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.0.") "install bitsandbytes>=0.45.3.")
except ImportError as err: except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.45.0 via " raise ImportError("Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.0` to use " "`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer.") from err "bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights( hf_weights_files, use_safetensors = self._prepare_weights(
......
...@@ -32,7 +32,7 @@ def set_default_torch_dtype(dtype: torch.dtype): ...@@ -32,7 +32,7 @@ def set_default_torch_dtype(dtype: torch.dtype):
def is_transformers_impl_compatible( def is_transformers_impl_compatible(
arch: str, arch: str,
module: Optional[transformers.PreTrainedModel] = None) -> bool: module: Optional["transformers.PreTrainedModel"] = None) -> bool:
mod = module or getattr(transformers, arch, None) mod = module or getattr(transformers, arch, None)
if mod is None: if mod is None:
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment