"vllm/vscode:/vscode.git/clone" did not exist on "978a4462bbc529ff204647543526e4caa08ed974"
Commit ad58e9b3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.1.post2' into v0.6.1.post2-dev

parents 408f663a 9ba0817f
...@@ -90,12 +90,12 @@ _MULTIMODAL_MODELS = { ...@@ -90,12 +90,12 @@ _MULTIMODAL_MODELS = {
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"), "PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"), "PixtralForConditionalGeneration"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"), "Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
} }
_CONDITIONAL_GENERATION_MODELS = { _CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"), "BartModel": ("bart", "BartForConditionalGeneration"),
......
...@@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer. # Gemma does not apply LoRA to the embedding layer.
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
# Lookup table keyed by an individual projection name; each value is a
# (stacked parameter name, index) pair. The three attention projections
# point at "qkv_proj" with indices 0-2, and the two MLP projections point
# at "gate_up_proj" with indices 0-1.
# NOTE(review): presumably the index is the shard's position inside the
# fused weight and the table is consumed by the bitsandbytes loader --
# confirm against the loader code, which is not visible here.
bitsandbytes_stacked_params_mapping = {
# shard_name: (weight_name, index)
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__( def __init__(
self, self,
......
...@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): ...@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
# Add an N dimension for number of images per prompt (currently 1). # Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0) data = data.unsqueeze(0)
elif is_list_of(data, Image.Image): elif is_list_of(data, Image.Image):
# we can't stack here because the images may have different num_patches
data = [ data = [
image_to_pixel_values(img, image_to_pixel_values(img,
image_size, image_size,
...@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): ...@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
max_num, max_num,
use_thumbnail=use_thumbnail) for img in data use_thumbnail=use_thumbnail) for img in data
] ]
data = torch.stack(data)
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True) trust_remote_code=True)
...@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
data=self._validate_pixel_values( data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True).flatten(0, 1)), flatten_bn(flatten_bn(pixel_values), concat=True)),
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
......
...@@ -600,7 +600,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA): ...@@ -600,7 +600,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
weight_loader( weight_loader(
param, param,
loaded_weight, loaded_weight,
weight_name, name,
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
......
import math
from array import array from array import array
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from itertools import tee from itertools import tee
...@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask ...@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
...@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ...@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer, ctx.model_config.tokenizer,
tokenizer_mode=ctx.model_config.tokenizer_mode) tokenizer_mode=ctx.model_config.tokenizer_mode)
mm_encoder = tokenizer.instruct.mm_encoder
mm_config = ctx.model_config.multimodal_config mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1) patch_size = mm_encoder.mm_config.image_patch_size
image_token_id = mm_encoder.special_ids.img
# approximate image size mm_config = ctx.model_config.multimodal_config
size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size
size = 256
image = Image.new("RGB", (size, size), color=0) image = Image.new("RGB", (size, size), color=0)
img_chunk = ImageChunk(image=image)
tokens = mm_encoder(img_chunk).tokens image_feature_size = (size**2) // (patch_size**2)
token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE,
tokens) num_image_tokens = image_feature_size * num_images
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * num_image_tokens
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - num_image_tokens)
seq_data = SequenceData(token_ids) seq_data = SequenceData(token_ids)
mm_data = {"image": max_num_images_per_request * [image]} mm_data = {"image": num_images * [image]}
return seq_data, mm_data return seq_data, mm_data
...@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext, ...@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return MultiModalInputs({"images": images}) return MultiModalInputs({"images": images})
def merge_multimodal_embeddings(input_ids: torch.Tensor, def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
inputs_embeds: torch.Tensor, multi_modal_data = llm_inputs.get("multi_modal_data")
image_features: Optional[List[torch.Tensor]], if multi_modal_data is not None and "image" in multi_modal_data:
image_id: int) -> torch.Tensor: tokenizer = cached_get_tokenizer(
text_locations = input_ids != image_id ctx.model_config.tokenizer,
image_locations = input_ids == image_id tokenizer_mode=ctx.model_config.tokenizer_mode)
seq_len = input_ids.shape[0]
N_txt = text_locations.sum().item() mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
_, D_txt = inputs_embeds.shape image_token_id = mm_encoder.special_ids.img
N_img, D_img = image_features.shape
assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal " if image_token_id not in llm_inputs['prompt_token_ids']:
"to image features dim {D_img}") raise ValueError(
assert (seq_len == N_txt + (f"You've passed {llm_inputs=} without {image_token_id=}"
N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img " " Make sure to process your input via mistral_common's"
f"{(N_txt, N_img, image_locations.sum().item())}") " tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."))
inputs_embeds[image_locations, :] = image_features return llm_inputs
return inputs_embeds
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self, def __init__(self,
...@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
return None return None
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
# always take last images # if passed as batch take all images
images = [images[-1][i] for i in range(images.size(1))] N, B, C, W, H = images.shape
images = images.reshape(N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list): elif isinstance(images, list):
# always take last images # if passed as list flatten lists of tensors
images = [images[-1][i] for i in range(len(images[0]))] flatten_images = []
for imgs_per_req in images:
imgs_per_req = [
imgs_per_req[i] for i in range(imgs_per_req.size(0))
] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req
flatten_images.extend(imgs_per_req)
images = flatten_images
return images return images
......
...@@ -50,6 +50,7 @@ from vllm.multimodal.base import MultiModalInputs ...@@ -50,6 +50,7 @@ from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from vllm.utils import is_list_of
from .utils import flatten_bn, is_pp_missing_parameter, make_layers from .utils import flatten_bn, is_pp_missing_parameter, make_layers
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -697,9 +698,12 @@ def input_processor_for_qwen(ctx: InputContext, ...@@ -697,9 +698,12 @@ def input_processor_for_qwen(ctx: InputContext,
raise ValueError( raise ValueError(
f"Expected img embeds to be have 3 dimensions, got {num_dims}") f"Expected img embeds to be have 3 dimensions, got {num_dims}")
num_images = 1 if num_dims == 2 else image_data.shape[0] num_images = 1 if num_dims == 2 else image_data.shape[0]
else: elif isinstance(image_data, Image.Image):
# TODO - handle multiple image inputs once the API is solidified
num_images = 1 num_images = 1
elif is_list_of(image_data, Image.Image):
num_images = len(image_data)
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
if prompt is None: if prompt is None:
prompt = tokenizer.decode(prompt_token_ids) prompt = tokenizer.decode(prompt_token_ids)
...@@ -780,11 +784,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: ...@@ -780,11 +784,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but " f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
f"received shape [{data.shape}]") f"received shape [{data.shape}]")
pixel_values = data pixel_values = data
else: else:
transform = build_normalization_transform(image_size) transform = build_normalization_transform(image_size)
# TODO - handle multiple image inputs once the API is solidified if not isinstance(data, (list, tuple)):
transformed_images = [transform(data)] data = [data]
transformed_images = [transform(datum) for datum in data]
pixel_values = torch.stack(transformed_images, dim=0) pixel_values = torch.stack(transformed_images, dim=0)
return MultiModalInputs({"pixel_values": pixel_values}) return MultiModalInputs({"pixel_values": pixel_values})
......
...@@ -1055,6 +1055,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1055,6 +1055,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -1078,6 +1081,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1078,6 +1081,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1) loaded_weight = loaded_weight.reshape(-1)
try: try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
except KeyError: except KeyError:
print(params_dict.keys()) print(params_dict.keys())
......
...@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence ...@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union from typing import Union
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus) SequenceGroup, SequenceStatus)
...@@ -92,7 +93,7 @@ class RequestOutput: ...@@ -92,7 +93,7 @@ class RequestOutput:
self, self,
request_id: str, request_id: str,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: Optional[List[int]],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
finished: bool, finished: bool,
...@@ -113,19 +114,26 @@ class RequestOutput: ...@@ -113,19 +114,26 @@ class RequestOutput:
self.encoder_prompt_token_ids = encoder_prompt_token_ids self.encoder_prompt_token_ids = encoder_prompt_token_ids
@classmethod @classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": def from_seq_group(cls,
if seq_group.sampling_params is None: seq_group: SequenceGroup) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError( raise ValueError(
"Sampling parameters are missing for a CompletionRequest.") "Sampling parameters are missing for a CompletionRequest.")
finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
seqs = seq_group.get_seqs() seqs = seq_group.get_seqs()
if len(seqs) == 1: if len(seqs) == 1:
top_n_seqs = seqs top_n_seqs = seqs
else: else:
# Get the top-n sequences. # Get the top-n sequences.
n = seq_group.sampling_params.n n = sampling_params.n
if seq_group.sampling_params.use_beam_search: if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score( sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty) sampling_params.length_penalty)
else: else:
sorting_key = lambda seq: seq.get_cumulative_logprob() sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
...@@ -135,26 +143,49 @@ class RequestOutput: ...@@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence # NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the # always has the logprobs of the sampled tokens even if the
# logprobs are not requested. # logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None include_logprobs = sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length text_buffer_length = sampling_params.output_text_buffer_length
outputs = [ delta = sampling_params.output_kind == RequestOutputKind.DELTA
CompletionOutput(
seqs.index(seq), outputs = []
seq.get_output_text_to_return(text_buffer_length), include_prompt = True
seq.data._output_token_ids, for seq in top_n_seqs:
seq.get_cumulative_logprob() if include_logprobs else None, output_text = seq.get_output_text_to_return(
seq.output_logprobs if include_logprobs else None, text_buffer_length, delta)
SequenceStatus.get_finished_reason(seq.status), output_token_ids = seq.get_output_token_ids_to_return(delta)
seq.stop_reason) for seq in top_n_seqs output_logprobs = seq.output_logprobs if include_logprobs else None
]
if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
include_prompt = False
outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))
# Every sequence in the sequence group should have the same prompt. # Every sequence in the sequence group should have the same prompt.
prompt = seq_group.prompt if include_prompt:
prompt_token_ids = seq_group.prompt_token_ids prompt = seq_group.prompt
encoder_prompt = seq_group.encoder_prompt prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids encoder_prompt = seq_group.encoder_prompt
prompt_logprobs = seq_group.prompt_logprobs encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
finished = seq_group.is_finished() prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time) seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id, return cls(seq_group.request_id,
......
import logging import logging
from typing import Callable, Optional, Union
import vllm.envs as envs import vllm.envs as envs
...@@ -29,3 +30,15 @@ def load_general_plugins(): ...@@ -29,3 +30,15 @@ def load_general_plugins():
except Exception: except Exception:
logger.exception("Failed to load general plugin: %s", logger.exception("Failed to load general plugin: %s",
plugin.name) plugin.name)
_torch_compile_backend: Optional[Union[Callable, str]] = None
# Record `backend` (a callable or a string) in the module-level
# `_torch_compile_backend` variable, replacing any previous value.
# Returns nothing.
def set_torch_compile_backend(backend: Union[Callable, str]):
# `global` is needed because the assignment below rebinds the module
# variable rather than creating a local one.
global _torch_compile_backend
_torch_compile_backend = backend
# Return the current value of the module-level `_torch_compile_backend`
# variable: a callable, a string, or None.
def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
return _torch_compile_backend
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
from enum import IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
...@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits ...@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from.""" to sample from."""
class RequestOutputKind(Enum):
# Selects how much of the generated output each RequestOutput carries.
# Return entire output so far in every RequestOutput
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOutputs
FINAL_ONLY = 2
class SamplingParams( class SamplingParams(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
...@@ -147,6 +156,7 @@ class SamplingParams( ...@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors: Optional[Any] = None logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
# The below fields are not supposed to be used as an input. # The below fields are not supposed to be used as an input.
# They are set in post_init. # They are set in post_init.
...@@ -182,6 +192,7 @@ class SamplingParams( ...@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors: Optional[List[LogitsProcessor]] = None, logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int, truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None, msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
) -> "SamplingParams": ) -> "SamplingParams":
return SamplingParams( return SamplingParams(
n=1 if n is None else n, n=1 if n is None else n,
...@@ -213,6 +224,7 @@ class SamplingParams( ...@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens=spaces_between_special_tokens, spaces_between_special_tokens=spaces_between_special_tokens,
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind,
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -317,6 +329,9 @@ class SamplingParams( ...@@ -317,6 +329,9 @@ class SamplingParams(
raise ValueError( raise ValueError(
"stop strings are only supported when detokenize is True. " "stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.") "Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None: def _verify_beam_search(self) -> None:
if self.best_of == 1: if self.best_of == 1:
......
...@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod ...@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
Optional, Set, Tuple, Union, cast) from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
import msgspec import msgspec
import torch import torch
...@@ -407,6 +408,10 @@ class Sequence: ...@@ -407,6 +408,10 @@ class Sequence:
self.status = SequenceStatus.WAITING self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_text_offset: int = 0
# Used for incremental detokenization # Used for incremental detokenization
self.prefix_offset = 0 self.prefix_offset = 0
self.read_offset = 0 self.read_offset = 0
...@@ -462,11 +467,37 @@ class Sequence: ...@@ -462,11 +467,37 @@ class Sequence:
return self.prompt_adapter_request.prompt_adapter_id \ return self.prompt_adapter_request.prompt_adapter_id \
if self.prompt_adapter_request else 0 if self.prompt_adapter_request else 0
def get_output_text_to_return(self, buffer_length: int): def get_output_text_to_return(self, buffer_length: int,
delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished. # We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished() truncate = buffer_length and not self.is_finished()
return self.output_text[:-buffer_length] if truncate else ( if not delta:
self.output_text) return self.output_text[:-buffer_length] if truncate else (
self.output_text)
length = len(self.output_text)
if truncate:
length -= buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
return self.output_text[last_offset:length]
return ""
def get_output_token_ids_to_return(self,
delta: bool) -> GenericSequence[int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
# Non-delta mode: hand back whatever get_output_token_ids() returns and
# leave the delta offset untouched.
if not delta:
return self.get_output_token_ids()
# Delta mode: compare the current output length with the offset saved
# by the previous delta call.
length = self.get_output_len()
last_offset = self._last_token_ids_offset
if last_offset < length:
# New tokens exist: advance the saved offset first, then return the
# slice starting at the old offset. Note that this call mutates
# self._last_token_ids_offset, so it is not idempotent.
self._last_token_ids_offset = length
return self.data._output_token_ids[last_offset:]
# Nothing new since the last delta call: return an empty tuple.
return ()
def hash_of_block(self, logical_idx: int) -> int: def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size # TODO This can produce incorrect hash when block size > prompt size
......
...@@ -4,7 +4,9 @@ import json ...@@ -4,7 +4,9 @@ import json
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Type, Union from typing import Any, Dict, Optional, Type, Union
from huggingface_hub import file_exists, hf_hub_download import huggingface_hub
from huggingface_hub import (file_exists, hf_hub_download,
try_to_load_from_cache)
from transformers import GenerationConfig, PretrainedConfig from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import ( from transformers.models.auto.image_processing_auto import (
get_image_processor_config) get_image_processor_config)
...@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision, ...@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
if Path(model).exists(): if Path(model).exists():
return (Path(model) / config_name).is_file() return (Path(model) / config_name).is_file()
return file_exists(model, config_name, revision=revision, token=token) # Offline mode support: Check if config file is cached already
cached_filepath = try_to_load_from_cache(repo_id=model,
filename=config_name,
revision=revision)
if isinstance(cached_filepath, str):
# The config file exists in the cache, so we can continue trying to load
return True
# NB: file_exists will only check for the existence of the config file on
# hf_hub. This will fail in offline mode.
try:
return file_exists(model, config_name, revision=revision, token=token)
except huggingface_hub.errors.OfflineModeIsEnabled:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
return False
def get_config( def get_config(
...@@ -102,6 +119,15 @@ def get_config( ...@@ -102,6 +119,15 @@ def get_config(
token=kwargs.get("token")): token=kwargs.get("token")):
config_format = ConfigFormat.MISTRAL config_format = ConfigFormat.MISTRAL
else: else:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists(model,
HF_CONFIG_NAME,
revision=revision,
token=kwargs.get("token"))
raise ValueError(f"No supported config format found in {model}") raise ValueError(f"No supported config format found in {model}")
if config_format == ConfigFormat.HF: if config_format == ConfigFormat.HF:
...@@ -206,6 +232,8 @@ def load_params_config(model, revision) -> PretrainedConfig: ...@@ -206,6 +232,8 @@ def load_params_config(model, revision) -> PretrainedConfig:
config_dict["tie_word_embeddings"] = config_dict.get( config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False) "tie_embeddings", False)
config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000)
config_dict["max_position_embeddings"] = config_dict.get(
"max_position_embeddings", 128_000)
if config_dict.get("moe") is not None: if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"] config_dict["architectures"] = ["MixtralForCausalLM"]
......
...@@ -82,6 +82,9 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " ...@@ -82,6 +82,9 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/" "currently supported with encoder/"
"decoder models.") "decoder models.")
STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with "
"encoder/decoder models.")
# Efficiently import all enc/dec error strings # Efficiently import all enc/dec error strings
# rather than having to import all of the above # rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = { STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
...@@ -97,6 +100,7 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = { ...@@ -97,6 +100,7 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
"STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
"STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
} }
# Constants related to forcing the attention backend selection # Constants related to forcing the attention backend selection
......
...@@ -2,6 +2,7 @@ import warnings ...@@ -2,6 +2,7 @@ import warnings
try: try:
import vllm.commit_id import vllm.commit_id
__commit__ = vllm.commit_id.__commit__ __commit__ = vllm.commit_id.__commit__
except Exception as e: except Exception as e:
warnings.warn(f"Failed to read commit hash:\n{e}", warnings.warn(f"Failed to read commit hash:\n{e}",
...@@ -9,4 +10,4 @@ except Exception as e: ...@@ -9,4 +10,4 @@ except Exception as e:
stacklevel=2) stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER" __commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.6.1" __version__ = "0.6.1.post2"
...@@ -15,7 +15,7 @@ from vllm.model_executor.model_loader import get_model ...@@ -15,7 +15,7 @@ from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs) MultiModalInputs)
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
...@@ -121,6 +121,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]): ...@@ -121,6 +121,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# Lazy initialization. # Lazy initialization.
self.model: nn.Module # Set after init_Model self.model: nn.Module # Set after init_Model
if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(model_config=self.model_config, self.model = get_model(model_config=self.model_config,
load_config=self.load_config, load_config=self.load_config,
......
...@@ -53,7 +53,7 @@ from vllm.worker.model_runner_base import ( ...@@ -53,7 +53,7 @@ from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict) _init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
...@@ -1064,10 +1064,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1064,10 +1064,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"This may lead to less accurate results!") "This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile( self.model = torch.compile(
self.model, self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager") backend=backend)
def save_sharded_state( def save_sharded_state(
self, self,
...@@ -1489,6 +1491,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1489,6 +1491,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
virtual_engine=virtual_engine) virtual_engine=virtual_engine)
@torch.inference_mode() @torch.inference_mode()
@dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
def execute_model( def execute_model(
self, self,
model_input: ModelInputForGPUWithSamplingMetadata, model_input: ModelInputForGPUWithSamplingMetadata,
......
import dataclasses import dataclasses
import pickle
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime
from functools import wraps
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar) TypeVar)
...@@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict( ...@@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict(
return tensor_dict return tensor_dict
def dump_input_when_exception(exclude_args: Optional[List[int]] = None,
                              exclude_kwargs: Optional[List[str]] = None):
    """Build a decorator that pickles a function's inputs when it raises.

    When the decorated function raises an ``Exception``, its positional and
    keyword arguments (minus the excluded ones) are pickled to
    ``/tmp/err_<func name>_input_<timestamp>.pkl`` and the error is re-raised
    as the same exception type, with a message that mentions the dump file
    and with the original error attached as ``__cause__``.

    Dumping is best effort: a failure to pickle or write the inputs never
    replaces the original error. If the exception type cannot be rebuilt
    from a single message string, the original exception is re-raised
    unchanged.

    Args:
        exclude_args: Indices of positional arguments to leave out of the
            dump. ``None`` excludes nothing.
        exclude_kwargs: Names of keyword arguments to leave out of the dump.
            ``None`` excludes nothing.

    Returns:
        A decorator that wraps a callable with the dump-on-error behaviour.
    """
    # Function-scope import so this block does not rely on a module-level
    # ``import os``.
    import os

    # Normalise once instead of rebuilding ``(x or [])`` for every argument.
    skipped_positions = frozenset(exclude_args or ())
    skipped_keywords = frozenset(exclude_kwargs or ())

    def _inner(func):

        @wraps(func)
        def _wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as err:
                timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
                filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl"

                # Keyword arguments first, then positionals as "arg_<i>".
                dumped_inputs = {
                    k: v
                    for k, v in kwargs.items() if k not in skipped_keywords
                }
                for i, arg in enumerate(args):
                    if i not in skipped_positions:
                        dumped_inputs[f"arg_{i}"] = arg

                # NOTE: the dump is a pickle file; it must only ever be
                # loaded by trusted code.
                try:
                    with open(filename, "wb") as filep:
                        pickle.dump(dumped_inputs, filep)
                except Exception as dump_err:
                    # Unpicklable input or unwritable location: do not let
                    # this hide the real failure, and do not leave a
                    # truncated dump behind.
                    try:
                        os.remove(filename)
                    except OSError:
                        pass
                    detail = f"failed to dump input: {dump_err!r}"
                else:
                    detail = f"input dumped to {filename}"

                message = f"Error in model execution ({detail}): {str(err)}"
                try:
                    wrapped = type(err)(message)
                except Exception:
                    # Some exception classes (e.g. UnicodeDecodeError) do
                    # not accept a single message argument.
                    wrapped = None
                if wrapped is None:
                    # Re-raise the original error untouched.
                    raise
                raise wrapped from err

        return _wrapper

    return _inner
class BroadcastableModelInput(ABC): class BroadcastableModelInput(ABC):
@abstractmethod @abstractmethod
......
...@@ -4,13 +4,6 @@ from dataclasses import dataclass, field ...@@ -4,13 +4,6 @@ from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union) Union)
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
import torch import torch
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
...@@ -36,6 +29,8 @@ if TYPE_CHECKING: ...@@ -36,6 +29,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
def seq_output_builder(): def seq_output_builder():
return SequenceOutput( return SequenceOutput(
...@@ -230,12 +225,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -230,12 +225,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self._base_model_runner: GPUModelRunnerBase = base_model_runner self._base_model_runner: GPUModelRunnerBase = base_model_runner
self.is_multi_step = self.scheduler_config.is_multi_step self.is_multi_step = self.scheduler_config.is_multi_step
# used to copy tensors from GPU to CPU asynchronously
self._copy_stream = torch.cuda.Stream()
self.pinned_sampled_token_ids: Optional[torch.Tensor] = None self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
self.pythonization_cache = PythonizationCache() self.pythonization_cache = PythonizationCache()
@functools.cached_property
def _copy_stream(self):
    """CUDA stream used to copy tensors from GPU to CPU asynchronously.

    The stream is built on first access and cached on the instance.
    """
    stream = torch.cuda.Stream()
    return stream
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
model_input = (StatefulModelInput.from_broadcasted_tensor_dict( model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
...@@ -486,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -486,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def _advance_step(self, model_input: StatefulModelInput, def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput: out: SamplerOutput) -> StatefulModelInput:
frozen_model_input = model_input.frozen_model_input if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS:
assert frozen_model_input is not None raise ValueError(
assert frozen_model_input.attn_metadata is not None f"Multi-step not supported for attention backend: "
f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.")
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
num_seqs = model_input.num_seqs num_seqs = model_input.num_seqs
num_queries = model_input.num_queries num_queries = model_input.num_queries
assert num_seqs > 0 frozen_model_input = model_input.frozen_model_input
assert num_queries > 0 assert frozen_model_input is not None
assert num_seqs >= num_queries
attn_metadata = frozen_model_input.attn_metadata attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata) assert attn_metadata is not None
attn_metadata.advance_step( attn_metadata.advance_step(
frozen_model_input, frozen_model_input,
model_input.cached_outputs[-1].sampled_token_ids, self.block_size, sampled_token_ids,
num_seqs, num_queries) self.block_size,
num_seqs,
if frozen_model_input.seq_lens is not None: num_queries,
for i in range(num_queries): )
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
return model_input return model_input
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment