Commit ad58e9b3 authored by zhuwenwen
Browse files

Merge tag 'v0.6.1.post2' into v0.6.1.post2-dev

parents 408f663a 9ba0817f
......@@ -90,12 +90,12 @@ _MULTIMODAL_MODELS = {
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
......
......@@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
......
......@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because the images may have different num_patches
data = [
image_to_pixel_values(img,
image_size,
......@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
max_num,
use_thumbnail=use_thumbnail) for img in data
]
data = torch.stack(data)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
......@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True).flatten(0, 1)),
flatten_bn(flatten_bn(pixel_values), concat=True)),
)
raise AssertionError("This line should be unreachable.")
......
......@@ -600,7 +600,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
weight_loader(
param,
loaded_weight,
weight_name,
name,
shard_id=shard_id,
expert_id=expert_id,
)
......
import math
from array import array
from dataclasses import dataclass, fields
from itertools import tee
......@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
......@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
mm_config = ctx.model_config.multimodal_config
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1)
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img
# approximate image size
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
mm_config = ctx.model_config.multimodal_config
num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)
tokens = mm_encoder(img_chunk).tokens
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens)
image_feature_size = (size**2) // (patch_size**2)
num_image_tokens = image_feature_size * num_images
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)
seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]}
mm_data = {"image": num_images * [image]}
return seq_data, mm_data
......@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images})
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: Optional[List[torch.Tensor]],
image_id: int) -> torch.Tensor:
text_locations = input_ids != image_id
image_locations = input_ids == image_id
seq_len = input_ids.shape[0]
def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is not None and "image" in multi_modal_data:
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode)
N_txt = text_locations.sum().item()
_, D_txt = inputs_embeds.shape
N_img, D_img = image_features.shape
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
"to image features dim {D_img}")
assert (seq_len == N_txt +
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
f"{(N_txt, N_img, image_locations.sum().item())}")
if image_token_id not in llm_inputs['prompt_token_ids']:
raise ValueError(
(f"You've passed {llm_inputs=} without {image_token_id=}"
" Make sure to process your input via mistral_common's"
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."))
inputs_embeds[image_locations, :] = image_features
return inputs_embeds
return llm_inputs
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
......@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
return None
if isinstance(images, torch.Tensor):
# always take last images
images = [images[-1][i] for i in range(images.size(1))]
# if passed as batch take all images
N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# always take last images
images = [images[-1][i] for i in range(len(images[0]))]
# if passed as list flatten lists of tensors
flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
flatten_images.extend(imgs_per_req)
images = flatten_images
return images
......
......@@ -50,6 +50,7 @@ from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.utils import is_list_of
from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops
......@@ -697,9 +698,12 @@ def input_processor_for_qwen(ctx: InputContext,
raise ValueError(
f"Expected img embeds to be have 3 dimensions, got {num_dims}")
num_images = 1 if num_dims == 2 else image_data.shape[0]
else:
# TODO - handle multiple image inputs once the API is solidified
elif isinstance(image_data, Image.Image):
num_images = 1
elif is_list_of(image_data, Image.Image):
num_images = len(image_data)
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
......@@ -780,11 +784,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
f"received shape [{data.shape}]")
pixel_values = data
else:
transform = build_normalization_transform(image_size)
# TODO - handle multiple image inputs once the API is solidified
transformed_images = [transform(data)]
if not isinstance(data, (list, tuple)):
data = [data]
transformed_images = [transform(datum) for datum in data]
pixel_values = torch.stack(transformed_images, dim=0)
return MultiModalInputs({"pixel_values": pixel_values})
......
......@@ -1055,6 +1055,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -1078,6 +1081,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
......
......@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
......@@ -92,7 +93,7 @@ class RequestOutput:
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
finished: bool,
......@@ -113,19 +114,26 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
if seq_group.sampling_params is None:
def from_seq_group(cls,
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = seq_group.sampling_params.n
if seq_group.sampling_params.use_beam_search:
n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
......@@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.data._output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
]
include_logprobs = sampling_params.logprobs is not None
text_buffer_length = sampling_params.output_text_buffer_length
delta = sampling_params.output_kind == RequestOutputKind.DELTA
outputs = []
include_prompt = True
for seq in top_n_seqs:
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta)
output_logprobs = seq.output_logprobs if include_logprobs else None
if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
include_prompt = False
outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))
# Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
if include_prompt:
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
......
import logging
from typing import Callable, Optional, Union
import vllm.envs as envs
......@@ -29,3 +30,15 @@ def load_general_plugins():
except Exception:
logger.exception("Failed to load general plugin: %s",
plugin.name)
# Module-level slot holding the backend registered through
# ``set_torch_compile_backend``; ``None`` until something is registered.
_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
    """Store ``backend`` in the module-level slot.

    Args:
        backend: A callable or a string. It is stored as-is, without
            validation, and replaces any previously stored value.
    """
    global _torch_compile_backend
    _torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
    """Return the value last passed to ``set_torch_compile_backend``.

    Returns:
        The stored backend, or ``None`` if nothing has been stored yet.
    """
    return _torch_compile_backend
"""Sampling parameters for text generation."""
import copy
from enum import IntEnum
from enum import Enum, IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union
......@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""
class RequestOutputKind(Enum):
    """Selects what each ``RequestOutput`` of a request contains."""

    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutputs
    FINAL_ONLY = 2
class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
......@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input.
# They are set in post_init.
......@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
) -> "SamplingParams":
return SamplingParams(
n=1 if n is None else n,
......@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
)
def __post_init__(self) -> None:
......@@ -317,6 +329,9 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
......
......@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
import msgspec
import torch
......@@ -407,6 +408,10 @@ class Sequence:
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_text_offset: int = 0
# Used for incremental detokenization
self.prefix_offset = 0
self.read_offset = 0
......@@ -462,11 +467,37 @@ class Sequence:
return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int,
                              delta: bool = False) -> str:
    """Return the output text that may be handed to the caller.

    While the sequence is not finished, the last ``buffer_length``
    characters of ``self.output_text`` are withheld; once it is finished
    the full text is eligible.

    Args:
        buffer_length: Number of trailing characters to withhold while the
            sequence is unfinished. ``0`` withholds nothing.
        delta: If True, return only the text that became eligible since the
            previous delta call and advance ``_last_output_text_offset``.
            If False (the default, matching the older one-argument
            signature), return all eligible text and leave the offset
            untouched.

    Returns:
        The eligible text, or ``""`` in delta mode when nothing new is
        available.
    """
    # We return the full output text if the sequence is finished.
    truncate = buffer_length and not self.is_finished()
    if not delta:
        return self.output_text[:-buffer_length] if truncate else (
            self.output_text)

    # Index one past the last character that may be returned right now.
    length = len(self.output_text)
    if truncate:
        length -= buffer_length

    last_offset = self._last_output_text_offset
    if last_offset < length:
        # Remember how far we have emitted so the next call starts there.
        self._last_output_text_offset = length
        return self.output_text[last_offset:length]
    return ""
def get_output_token_ids_to_return(self,
                                   delta: bool) -> GenericSequence[int]:
    """Return output token ids, optionally only the ones not yet returned.

    Args:
        delta: If False, return whatever ``self.get_output_token_ids()``
            returns. If True, return only the ids appended since the
            previous delta call and advance ``_last_token_ids_offset``.

    Returns:
        The selected token ids; an empty tuple in delta mode when no new
        ids exist.
    """
    if not delta:
        return self.get_output_token_ids()

    length = self.get_output_len()
    last_offset = self._last_token_ids_offset
    if last_offset < length:
        # Remember how many ids have been handed out so far.
        self._last_token_ids_offset = length
        return self.data._output_token_ids[last_offset:]
    return ()
def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
......
......@@ -4,7 +4,9 @@ import json
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
from huggingface_hub import file_exists, hf_hub_download
import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download,
try_to_load_from_cache)
from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import (
get_image_processor_config)
......@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
if Path(model).exists():
return (Path(model) / config_name).is_file()
return file_exists(model, config_name, revision=revision, token=token)
# Offline mode support: Check if config file is cached already
cached_filepath = try_to_load_from_cache(repo_id=model,
filename=config_name,
revision=revision)
if isinstance(cached_filepath, str):
# The config file exists in cache- we can continue trying to load
return True
# NB: file_exists will only check for the existence of the config file on
# hf_hub. This will fail in offline mode.
try:
return file_exists(model, config_name, revision=revision, token=token)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
return False
def get_config(
......@@ -102,6 +119,15 @@ def get_config(
token=kwargs.get("token")):
config_format = ConfigFormat.MISTRAL
else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token"))
raise ValueError(f"No supported config format found in {model}")
if config_format == ConfigFormat.HF:
......@@ -206,6 +232,8 @@ def load_params_config(model, revision) -> PretrainedConfig:
config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False)
config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
config_dict["max_position_embeddings"] = config_dict.get(
"max_position_embeddings", 128_000)
if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"]
......
......@@ -82,6 +82,9 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"decoder models.")
STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with "
"encoder/decoder models.")
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
......@@ -97,6 +100,7 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
"STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
}
# Constants related to forcing the attention backend selection
......
......@@ -2,6 +2,7 @@ import warnings
try:
import vllm.commit_id
__commit__ = vllm.commit_id.__commit__
except Exception as e:
warnings.warn(f"Failed to read commit hash:\n{e}",
......@@ -9,4 +10,4 @@ except Exception as e:
stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.6.1"
__version__ = "0.6.1.post2"
......@@ -15,7 +15,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
......@@ -121,6 +121,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# Lazy initialization.
self.model: nn.Module # Set after init_Model
if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
......
......@@ -53,7 +53,7 @@ from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
_init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
......@@ -1064,10 +1064,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager")
backend=backend)
def save_sharded_state(
self,
......@@ -1489,6 +1491,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
virtual_engine=virtual_engine)
@torch.inference_mode()
@dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
......
import dataclasses
import pickle
from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
......@@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict(
return tensor_dict
def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
                              exclude_kwargs: Optional[List[str]] = None):
    """Build a decorator that pickles a call's inputs when the call raises.

    When the wrapped function raises an ``Exception``, its arguments are
    pickled to ``/tmp/err_<func name>_input_<timestamp>.pkl`` and an
    exception of the same type is raised whose message names that file; the
    original exception is attached as ``__cause__``.

    Dumping is best-effort: a failure to open or pickle the file is reported
    in the message instead of replacing the original error. If the exception
    type cannot be rebuilt from a single message string, the original
    exception is re-raised unchanged.

    Args:
        exclude_args: Positional indices to leave out of the dump.
        exclude_kwargs: Keyword names to leave out of the dump.

    Returns:
        A decorator to apply to the function whose inputs should be dumped.
    """
    skipped_positions = exclude_args or []
    skipped_keywords = exclude_kwargs or []

    def _inner(func):

        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as err:
                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
                filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"

                # Positional arguments are stored under "arg_<index>".
                dumped_inputs = {
                    k: v
                    for k, v in kwargs.items() if k not in skipped_keywords
                }
                for i, arg in enumerate(args):
                    if i not in skipped_positions:
                        dumped_inputs[f"arg_{i}"] = arg

                # The dump is a debugging aid only: it must never hide the
                # error that triggered it (e.g. unpicklable inputs or an
                # unwritable directory).
                try:
                    with open(filename, "wb") as filep:
                        pickle.dump(dumped_inputs, filep)
                except Exception as dump_err:
                    message = (
                        f"Error in model execution (failed to dump input to "
                        f"{filename}: {dump_err!r}): {err}")
                else:
                    message = (
                        f"Error in model execution (input dumped to "
                        f"{filename}): {err}")

                # Some exception classes do not accept a lone message
                # argument; in that case surface the original unchanged.
                try:
                    wrapped = type(err)(message)
                except Exception:
                    wrapped = None
                if wrapped is None:
                    raise
                raise wrapped from err

        return _wrapper

    return _inner
class BroadcastableModelInput(ABC):
@abstractmethod
......
......@@ -4,13 +4,6 @@ from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
import torch
from vllm.distributed import get_pp_group
......@@ -36,6 +29,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
def seq_output_builder():
return SequenceOutput(
......@@ -230,12 +225,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self._base_model_runner: GPUModelRunnerBase = base_model_runner
self.is_multi_step = self.scheduler_config.is_multi_step
# used to copy tensors from GPU to CPU asynchronously
self._copy_stream = torch.cuda.Stream()
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
self.pythonization_cache = PythonizationCache()
@functools.cached_property
def _copy_stream(self):
    """CUDA stream used to copy tensors from GPU to CPU asynchronously.

    Wrapped in ``cached_property``, so ``torch.cuda.Stream()`` is only
    called on first access rather than when the runner is constructed.
    """
    return torch.cuda.Stream()
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
......@@ -486,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
raise ValueError(
f"Multi-step not supported for attention backend: "
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
num_seqs = model_input.num_seqs
num_queries = model_input.num_queries
assert num_seqs > 0
assert num_queries > 0
assert num_seqs >= num_queries
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
assert attn_metadata is not None
attn_metadata.advance_step(
frozen_model_input,
model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
num_seqs, num_queries)
if frozen_model_input.seq_lens is not None:
for i in range(num_queries):
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
sampled_token_ids,
self.block_size,
num_seqs,
num_queries,
)
return model_input
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment