Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
......@@ -42,13 +42,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
......
......@@ -45,12 +45,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
from .utils import is_pp_missing_parameter, make_layers
......
......@@ -3,18 +3,16 @@ within a vision language model."""
import math
from array import array
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from PIL import Image
from torch import nn
from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention
from vllm_flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention
from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
from vllm.config import ModelConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -28,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible
......@@ -93,7 +97,7 @@ def input_processor_for_siglip(
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
......@@ -221,9 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):
class SiglipParallelAttention(nn.Module):
def __init__(
self,
......@@ -233,38 +235,30 @@ class SiglipTPAttention(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
if self.total_num_heads % tp_size != 0:
raise ValueError(
f"Number of attention heads ({self.total_num_heads}) "
"must be divisible by the tensor model parallel size"
f" ({tp_size}).")
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.embed_dim // self.total_num_heads
if self.head_dim * self.total_num_heads != self.embed_dim:
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
    # Adjacent string literals are concatenated, but only the pieces that
    # carry an `f` prefix are interpolated, so every fragment containing a
    # placeholder needs its own prefix. Previously the middle fragment
    # lacked it and "{self.embed_dim}" was emitted verbatim.
    raise ValueError(f"embed_dim must be divisible by num_heads (got "
                     f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                     f" {self.num_heads}).")
self.qkv_size = self.num_heads * self.head_dim
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_heads=self.num_heads,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
)
self.attn_fn = self._basic_attention_forward
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward(
self,
......@@ -274,161 +268,27 @@ class SiglipTPAttention(nn.Module):
batch_size, q_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.split(
[self.qkv_size] * 3, dim=-1)
attn_output = self.attn_fn(
q=query_states,
k=key_states,
v=value_states,
batch_size=batch_size,
q_len=q_len,
)
attn_output, _ = self.out_proj(attn_output)
return attn_output
def _basic_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k_v_seq_len = k.shape[-2]
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
if attn_weights.size() != (
batch_size,
self.num_heads,
q_len,
k_v_seq_len,
):
raise ValueError(
"Attention weights should be of size "
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}")
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(q.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (
batch_size,
self.num_heads,
q_len,
self.head_dim,
):
raise ValueError(
"`attn_output` should be of size "
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._flash_attention_forward
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
**kwargs):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the
query, key, and value. (B, S, H, D)
"""
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout,
causal=False,
)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
self.attn_fn = self._sdpa_attention_forward
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
query_states = query_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
key_states = key_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale)
out = xops.memory_efficient_attention_forward(query_states,
key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._xformers_attention_forward
def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = memory_efficient_attention(q,
k,
v,
p=0.0,
scale=self.scale)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipTPAttention,
"flash_attention_2": SiglipFlashAttention2,
"sdpa": SiglipSdpaAttention,
"xformers": SiglipxFormersAttention,
}
return attn_output, None
class SiglipMLP(nn.Module):
......@@ -473,8 +333,14 @@ class SiglipEncoderLayer(nn.Module):
super().__init__()
self.embed_dim = config.hidden_size
# TODO(ChristopherCho): use TP'ed Attention block
self.self_attn = SiglipAttention(config)
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = SiglipSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
......@@ -577,14 +443,27 @@ class SiglipVisionTransformer(nn.Module):
self.config = config
embed_dim = config.hidden_size
if (num_hidden_layers_override is None
or num_hidden_layers_override == config.num_hidden_layers):
self.need_post_layernorm = True
elif num_hidden_layers_override > config.num_hidden_layers:
raise ValueError(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers")
else:
self.need_post_layernorm = False
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
if self.need_post_layernorm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = nn.Identity()
self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head)
if self.use_head:
......@@ -604,7 +483,6 @@ class SiglipVisionTransformer(nn.Module):
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
# TODO: add this back when pooled_output is used in inference
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
......@@ -623,12 +501,20 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = SiglipVisionTransformer(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
@property
def need_post_layernorm(self):
    """Expose the ``need_post_layernorm`` flag of the wrapped ``vision_model``."""
    vision_model = self.vision_model
    return vision_model.need_post_layernorm
def get_input_embeddings(self) -> nn.Module:
    """Return the ``patch_embedding`` module of the vision model's embeddings."""
    embeddings = self.vision_model.embeddings
    return embeddings.patch_embedding
......@@ -647,6 +533,11 @@ class SiglipVisionModel(nn.Module):
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self.need_post_layernorm):
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
......
......@@ -36,12 +36,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
class StablelmMLP(nn.Module):
......
......@@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
class Starcoder2Attention(nn.Module):
......
......@@ -8,7 +8,6 @@ from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast)
import librosa
import numpy as np
import torch
import torch.utils.checkpoint
......@@ -27,17 +26,18 @@ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights,
from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN = 128002
......@@ -48,13 +48,14 @@ logger = init_logger(__name__)
class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, 80, M)"""
data: NestedTensors
"""Shape: `(batch_size, num_audios, 80, M)"""
class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
data: torch.Tensor
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
......@@ -85,27 +86,41 @@ def dummy_data_for_ultravox(
audio_count = mm_counts["audio"]
audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [
_AUDIO_PLACEHOLDER_TOKEN
]) * get_ultravox_max_audio_tokens(ctx) * audio_count
audio_placeholder = array(
VLLM_TOKEN_ID_ARRAY_TYPE,
[_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
# Add a separator between each chunk.
audio_token_ids = (audio_placeholder +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids))
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
mm_dict = {
"audio":
audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
}
mm_dict = {"audio": [audio_and_sr] * audio_count}
return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
def input_mapper_for_ultravox(ctx: InputContext, data: object):
if isinstance(data, tuple):
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
if not isinstance(data, list):
data = [data]
audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx)
if sr != feature_extractor.sampling_rate:
try:
import librosa
except ImportError:
raise ImportError(
"Please install vllm[audio] for audio support.") from None
audio = librosa.resample(audio,
orig_sr=sr,
target_sr=feature_extractor.sampling_rate)
......@@ -116,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
# Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
return MultiModalInputs({
"audio_features":
feature_extractor(audio,
sampling_rate=sr,
padding="longest",
return_tensors="pt")["input_features"]
})
single_audio_features = feature_extractor(
audio, sampling_rate=sr, padding="longest",
return_tensors="pt")["input_features"]
raise NotImplementedError(f"Unsupported data type: {type(data)}")
# Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0))
return MultiModalInputs({"audio_features": audio_features})
def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
......@@ -133,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs
feature_extractor = whisper_feature_extractor(ctx)
audio_data, sample_rate = multi_modal_data["audio"]
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
feature_extractor_output_length = math.ceil(
(audio_length -
(feature_extractor.hop_length - 1)) / feature_extractor.hop_length)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audios = [audios]
audio_token_counts = []
for audio_data, sample_rate in audios:
audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling.
adjustment = feature_extractor.sampling_rate / sample_rate
audio_length = math.ceil(adjustment * audio_length)
feature_extractor_output_length = math.ceil(
(audio_length - (feature_extractor.hop_length - 1)) /
feature_extractor.hop_length)
uv_config = ctx.get_hf_config(UltravoxConfig)
audio_num_tokens = min(
max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
......@@ -159,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_num_tokens,
repeat_count=audio_token_counts,
)
# NOTE: Create a defensive copy of the original inputs
......@@ -337,7 +357,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
data=audio_features)
if audio_embeds is not None:
if not isinstance(audio_embeds, torch.Tensor):
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
......@@ -347,22 +367,38 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise AssertionError("This line should be unreachable.")
def _process_audio_input(
self, audio_input: UltravoxAudioInputs
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, audio_input: UltravoxAudioInputs) -> NestedTensors:
if audio_input["type"] == "audio_embeds":
return audio_input["data"]
audio_features = audio_input["data"]
if isinstance(audio_features, list):
# TODO: Batch these through the encoder/projector instead of
# serializing them.
return [
self._audio_features_to_embeddings(
features.unsqueeze(0)).squeeze(0)
for features in audio_features
]
else:
return self._audio_features_to_embeddings(audio_features)
if isinstance(audio_features, torch.Tensor):
# Combine the B and N dimensions for the encoder/projector
flattened = flatten_bn(audio_features)
flattened_embeddings = self._audio_features_to_embeddings(
flattened)
# Restore the original dimensions
embeddings = flattened_embeddings.unflatten(
0, audio_features.shape[:2])
return embeddings
result = []
# TODO: Batch heterogeneous tensors through the encoder/projector
for audio_features_item in audio_features:
if isinstance(audio_features_item, torch.Tensor):
result.append(
self._audio_features_to_embeddings(audio_features_item))
else:
embeddings = [
# Add a batch dimension to embed it, then remove it.
self._audio_features_to_embeddings(tensor.unsqueeze(0)
).squeeze(0)
for tensor in audio_features_item
]
result.append(embeddings)
return result
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
......@@ -379,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
with the `input_ids`.
Args:
input_features: A batch of audio inputs, [1, 80, M].
audio_features: A batch of audio inputs [B, N, 80, M].
"""
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
......
from typing import Dict, Iterable, List, Optional, Protocol, Tuple
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload)
import torch
import torch.nn as nn
......@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal import BatchedTensors
from vllm.multimodal.base import NestedTensors
from vllm.utils import is_pin_memory_available
......@@ -54,9 +55,73 @@ def init_vllm_registered_model(
)
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.

    Args:
        x: Either a single tensor, or a list with one entry per batch item.
        concat: Only consulted when ``x`` is a list. If true, the entries
            are concatenated along their first dimension into one tensor.

    Returns:
        * For a tensor input: the tensor with its first two dimensions
          merged (``concat`` is ignored).
        * For a list input with ``concat=True``: ``torch.cat(x)``.
        * For a list input otherwise: a flat list holding the elements
          obtained by iterating over each entry of ``x`` in order.
    """
    if isinstance(x, torch.Tensor):
        # Merge dim 0 (B) and dim 1 (N) into a single leading dimension.
        return x.flatten(0, 1)

    if concat:
        return torch.cat(x)

    # Remove one level of nesting: iterate each batch entry and collect
    # its items into a single flat list.
    return [x_n for x_b in x for x_n in x_b]
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings: NestedTensors) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: BatchedTensors,
multimodal_embeddings: NestedTensors,
placeholder_token_id: int) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
......@@ -67,30 +132,17 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
This updates ``inputs_embeds`` in place.
"""
mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum()
if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape
total_tokens = batch_size * batch_tokens
if num_expected_tokens != total_tokens:
expr = f"{batch_size} x {batch_tokens}"
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(multimodal_embeddings)
num_expected_tokens = mask.sum().item()
assert isinstance(num_expected_tokens, int)
flattened = _flatten_embeddings(multimodal_embeddings)
if flattened.shape[0] != num_expected_tokens:
expr = _embedding_count_expression(multimodal_embeddings)
raise ValueError(
f"Attempted to assign {expr} = {flattened.shape[0]} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = flattened
return inputs_embeds
......
......@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
......
from fractions import Fraction
from typing import Callable, Optional, Union
import torch
......@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
"""
def __init__(self,
packed_factor: int,
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs):
......@@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter):
"""
def __init__(self,
packed_factor: int,
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs):
......
from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors)
from .registry import MultiModalRegistry
......@@ -14,7 +14,6 @@ See also:
__all__ = [
"BatchedTensorInputs",
"BatchedTensors",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalInputs",
......
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
TypedDict, TypeVar, Union, cast, final)
import numpy as np
import torch
......@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import JSONTree, json_map_leaves
from vllm.utils import JSONTree, is_list_of, json_map_leaves
logger = init_logger(__name__)
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
......@@ -54,26 +46,24 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
@staticmethod
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
Recursively stacks lists of tensors when they all have the same shape.
"""
# may be list rather than tensors
if isinstance(tensors[0], list):
return [[t for t in tensor[0]]
for tensor in cast(List[List[torch.Tensor]], tensors)]
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
tensors_ = cast(List[torch.Tensor], tensors)
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
unbatched_shape = tensors_[0].shape[1:]
tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
for tensor in tensors_:
if tensor.shape[1:] != unbatched_shape:
return [tensor.squeeze(0) for tensor in tensors_]
return torch.cat(tensors_, dim=0)
return torch.stack(tensors_)
@staticmethod
def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
......@@ -102,7 +92,7 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists[k].append(v)
return {
k: MultiModalInputs._try_concat(item_list)
k: MultiModalInputs._try_stack(item_list)
for k, item_list in item_lists.items()
}
......@@ -112,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
return json_map_leaves(lambda x: x.to(device, non_blocking=True),
batched_inputs)
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
json_inputs,
)
return cast(BatchedTensorInputs, json_mapped)
_T = TypeVar("_T")
......
import base64
from functools import lru_cache
from io import BytesIO
from typing import List, Optional, Tuple, TypeVar, Union
from typing import Any, List, Optional, Tuple, TypeVar, Union
import librosa
import numpy as np
import soundfile
from PIL import Image
from vllm.connections import global_http_connection
......@@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str,
return image.convert(image_mode)
def try_import_audio_packages() -> Tuple[Any, Any]:
    """Import the optional audio dependencies on demand.

    Returns:
        The ``(librosa, soundfile)`` modules, in that order.

    Raises:
        ImportError: If either package cannot be imported. The message
            points at the ``vllm[audio]`` extra and the underlying import
            failure is deliberately not chained.
    """
    try:
        import librosa
        import soundfile
    except ImportError:
        install_hint = "Please install vllm[audio] for audio support."
        # `from None` hides the low-level import failure so that users only
        # see the actionable installation hint.
        raise ImportError(install_hint) from None
    else:
        return librosa, soundfile
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"):
audio_bytes = global_http_connection.get_bytes(
audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
......@@ -95,6 +105,8 @@ async def async_fetch_audio(
"""
Asynchronously fetch audio from a URL.
"""
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"):
audio_bytes = await global_http_connection.async_get_bytes(
audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
......@@ -108,6 +120,16 @@ async def async_fetch_audio(
return librosa.load(BytesIO(audio_bytes), sr=None)
def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    """Fetch audio from ``audio_url`` and package it as multi-modal data.

    Returns:
        A dict whose ``"audio"`` entry is the two-element tuple produced by
        ``fetch_audio`` (the audio data followed by its sampling rate).
    """
    waveform, sampling_rate = fetch_audio(audio_url)
    audio_item = (waveform, sampling_rate)
    return {"audio": audio_item}
def get_and_parse_image(image_url: str) -> MultiModalDataDict:
    """Fetch the image at ``image_url`` and package it as multi-modal data.

    Returns:
        A dict whose ``"image"`` entry is whatever ``fetch_image`` returned.
    """
    return {"image": fetch_image(image_url)}
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    """Asynchronously fetch audio and package it as multi-modal data.

    Returns:
        A dict whose ``"audio"`` entry is the two-element tuple produced by
        ``async_fetch_audio`` (the audio data followed by its sampling rate).
    """
    waveform, sampling_rate = await async_fetch_audio(audio_url)
    audio_item = (waveform, sampling_rate)
    return {"audio": audio_item}
......@@ -123,6 +145,8 @@ def encode_audio_base64(
sampling_rate: int,
) -> str:
"""Encode audio as base64."""
_, soundfile = try_import_audio_packages()
buffered = BytesIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV")
......@@ -189,10 +213,13 @@ def repeat_and_pad_placeholder_tokens(
prompt_token_ids: List[int],
*,
placeholder_token_id: int,
repeat_count: int = 1,
repeat_count: Union[int, List[int]],
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]:
if isinstance(repeat_count, int):
repeat_count = [repeat_count]
if prompt is None:
new_prompt = None
else:
......@@ -201,13 +228,6 @@ def repeat_and_pad_placeholder_tokens(
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
placeholder_token_count = prompt.count(placeholder_token_str)
# This is an arbitrary number to distinguish between the two cases
......@@ -216,28 +236,45 @@ def repeat_and_pad_placeholder_tokens(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", placeholder_token_str)
elif placeholder_token_count > 1:
logger.warning("Multiple multi-modal input is not supported yet, "
"so any extra placeholder tokens will be treated "
"as plain text.")
# The image tokens are removed to be consistent with HuggingFace
new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
if placeholder_token_count < len(repeat_count):
logger.warning(
"The number of multi-modal placeholder tokens in the prompt "
"is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text")
repeat_count = repeat_count[:placeholder_token_count]
prompt_parts = prompt.split(placeholder_token_str,
maxsplit=len(repeat_count))
new_prompt = ""
for i, repeat_count_item in enumerate(repeat_count):
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count_item,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1]
new_token_ids: List[int] = []
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
replacement_ids = repeat_and_pad_token(
placeholder_token_id,
repeat_count=repeat_count,
repeat_count=repeat_count[placeholder_token_idx],
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + 1:])
break
# No need to further scan the list since we replaced all tokens
if placeholder_token_idx >= len(repeat_count):
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
......
......@@ -21,7 +21,9 @@ _R = TypeVar("_R")
if pynvml.__file__.endswith("__init__.py"):
logger.warning(
"You are using a deprecated `pynvml` package. Please install"
" `nvidia-ml-py` instead. See https://pypi.org/project/pynvml "
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
" When both of them are installed, `pynvml` will take precedence"
" and cause errors. See https://pypi.org/project/pynvml "
"for more information.")
# NVML utils
......@@ -82,6 +84,9 @@ except ModuleNotFoundError:
def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string,"
" which means GPU support is disabled.")
physical_device_id = device_ids[device_id]
return int(physical_device_id)
else:
......
import os
from functools import lru_cache
from typing import Tuple
import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
logger = init_logger(__name__)
# ROCm does not support the `fork` start method. An unset variable is treated
# the same as an explicit `fork`, and both are overridden to `spawn` at import
# time.
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") in ("fork", None):
    logger.warning(
        "`fork` method is not supported by ROCm. "
        "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
        " `spawn` instead.", )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
......
......@@ -125,6 +125,15 @@ def main():
serve_parser.add_argument("model_tag",
type=str,
help="The model tag to serve")
# Optional YAML file from which CLI options are read. The default is an empty
# string, and the flag is explicitly marked as not required.
serve_parser.add_argument(
    "--config",
    type=str,
    default='',
    required=False,
    # Adjacent string literals are concatenated verbatim, so every fragment
    # must end with its own separating space; otherwise the rendered help
    # reads "...config file.Must be a YAML ... options:https://...".
    help="Read CLI options from a config file. "
    "Must be a YAML with the following options: "
    "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server"
)
serve_parser = make_arg_parser(serve_parser)
serve_parser.set_defaults(dispatch_function=serve)
......
......@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Union, cast)
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Optional, Set, Tuple, Union, cast)
import msgspec
import torch
......@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
def append_token_id(self, token_id: int, logprobs: Dict[int,
Logprob]) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
......@@ -814,6 +811,9 @@ class SequenceGroup:
self.is_single_seq = len(self.seqs) == 1
def is_finished(self) -> bool:
    """Report whether every sequence in this group has finished.

    When ``is_single_seq`` is set, the only sequence is queried directly
    rather than iterating over ``self.seqs``.
    """
    if not self.is_single_seq:
        return all(seq.is_finished() for seq in self.seqs)
    only_seq = self.seqs[0]
    return only_seq.is_finished()
def is_prefill(self) -> bool:
......@@ -886,7 +886,7 @@ class SequenceGroupMetadata(
request_id: str
is_prompt: bool
seq_data: Dict[int, SequenceData]
sampling_params: SamplingParams
sampling_params: Optional[SamplingParams]
block_tables: Dict[int, List[int]]
do_sample: bool = True
pooling_params: Optional[PoolingParams] = None
......@@ -1060,76 +1060,6 @@ class IntermediateTensors(
return f"IntermediateTensors(tensors={self.tensors})"
class SamplerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """Sampling result for one model step.

    ``outputs`` holds one entry per sequence group, each carrying the
    candidate next tokens for that group. The container protocol methods
    (indexing, ``len``) delegate to ``outputs`` so an instance can be used
    like a list, while the remaining optional fields carry tensors and
    timing information alongside it.
    """

    # NOTE: the struct is declared ``array_like``, so the field order below
    # is kept exactly as declared.
    outputs: List[CompletionSequenceGroupOutput]

    # On-device tensor containing probabilities of each token.
    sampled_token_probs: Optional[torch.Tensor] = None

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None

    # CPU tensor containing the sampled token ids. Used during multi-step to
    # return the sampled token ids from last rank to AsyncLLMEngine to be
    # 'broadcasted' to all other PP ranks for next step.
    sampled_token_ids_cpu: Optional[torch.Tensor] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

    # Optional prefill hidden states from the model
    # (used for models like EAGLE).
    prefill_hidden_states: Optional[torch.Tensor] = None

    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None

    # Time taken in the model execute function. This will include model
    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
    model_execute_time: Optional[float] = None

    def __getitem__(self, idx: int):
        """Return the ``idx``-th entry of ``outputs``."""
        return self.outputs[idx]

    def __setitem__(self, idx: int, value):
        """Replace the ``idx``-th entry of ``outputs``."""
        self.outputs[idx] = value

    def __len__(self):
        """Return the number of entries in ``outputs``."""
        return len(self.outputs)

    def __eq__(self, other: object):
        """Compare by ``outputs`` only; the tensor fields are ignored."""
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs

    def __repr__(self) -> str:
        """Render tensors by shape rather than by value to reduce noise."""
        probs = self.sampled_token_probs
        token_ids = self.sampled_token_ids
        probs_desc = "None" if probs is None else probs.shape
        token_ids_desc = "None" if token_ids is None else token_ids.shape
        return (
            f"SamplerOutput(outputs={self.outputs}, "
            f"sampled_token_probs={probs_desc}, "
            f"sampled_token_ids={token_ids_desc}, "
            f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
class PoolerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
......@@ -1293,6 +1223,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
@property
def is_first_multi_step(self) -> bool:
......@@ -1338,4 +1270,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None)
if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)
......@@ -5,13 +5,13 @@ from typing import Iterator, List, Optional, Tuple
import torch
from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SamplerOutput, SequenceData, SequenceGroupMetadata,
SequenceData, SequenceGroupMetadata,
get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
split_batch_by_proposal_len)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
from vllm.worker.worker_base import WorkerBase
SeqId = int
......@@ -88,17 +88,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
(all_tokens, all_probs, spec_logprobs,
all_hidden_states) = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
if not non_spec_indices:
# All sequence groups in batch have spec decoding enabled
contracted = self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
contracted = self._contract_batch(
contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
......@@ -121,14 +129,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
spec_seqs, spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=False)
non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
split_batch_by_proposal_len(
seq_group_metadata_list, proposal_lens_list)
target_seq_group_metadata_list = self._create_scoring_model_input(
seq_group_metadata_list=spec_seqs,
......@@ -171,7 +174,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
non_spec_expanded_bs = len(non_spec_target_token_ids)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
......@@ -181,7 +184,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
*target_token_ids.shape, target_hidden_states.shape[-1])
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
......@@ -196,24 +199,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states = None
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
all_tokens[non_spec_indices, :1] = \
non_spec_target_token_ids.unsqueeze(1)
all_probs[non_spec_indices, :1, :] = \
non_spec_target_probs.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = \
non_spec_target_logprobs.unsqueeze(1)
if all_hidden_states is not None:
all_hidden_states[
non_spec_indices, :1, :] = non_spec_target_hidden_states
assert non_spec_target_hidden_states is not None
all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states.unsqueeze(1)
if spec_indices:
all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states
def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape
# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)
target_probs = target_sampler_output.sampled_token_probs.reshape(
*target_token_ids.shape, self._vocab_size)
target_logprobs = target_sampler_output.logprobs.reshape(
target_probs.shape)
target_hidden_states = target_sampler_output.hidden_states
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states)
def _create_scoring_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
......@@ -345,8 +382,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size=1,
)
@staticmethod
def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int
sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
......@@ -361,10 +399,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
#
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
split_sizes = [
num_scoring_tokens,
sampler_output.sampled_token_ids.numel() - num_scoring_tokens
]
split_sizes = (num_scoring_tokens,
sampler_output.sampled_token_ids.numel() -
num_scoring_tokens)
(spec_probs, non_spec_probs
) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens
......@@ -382,32 +419,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else:
spec_hidden_states, non_spec_hidden_states = None, None
# Convert scores to tensors.
sampler_output.sampled_token_probs = spec_probs
sampler_output.sampled_token_ids = spec_sampled_tokens
sampler_output.logprobs = spec_logprobs
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)
return (spec_sampled_tokens, spec_probs, spec_logprobs,
spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
non_spec_logprobs, non_spec_hidden_states)
@staticmethod
def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
......@@ -417,8 +435,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""
return count(start=max(seq_ids) + 1)
@staticmethod
def _get_token_ids_to_score(
self,
full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]:
"""Given an int tensor of proposal token ids, return a list of
......@@ -439,8 +457,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([
full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids))
])
token_ids_to_score.extend(full_spec_token_ids[:i + 1]
for i in range(len(full_spec_token_ids)))
return token_ids_to_score
......@@ -3,6 +3,7 @@ from typing import List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.sampler import SamplerOutput
try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
......@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
......
......@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
......
......@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment