Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
...@@ -42,13 +42,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -42,13 +42,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
......
...@@ -45,12 +45,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -45,12 +45,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .utils import is_pp_missing_parameter, make_layers from .utils import is_pp_missing_parameter, make_layers
......
...@@ -3,18 +3,16 @@ within a vision language model.""" ...@@ -3,18 +3,16 @@ within a vision language model."""
import math import math
from array import array from array import array
from typing import Iterable, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from transformers.models.siglip.modeling_siglip import SiglipAttention from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention
from vllm_flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -28,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer, ...@@ -28,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
try:
from xformers import ops as xops
USE_XFORMERS_OPS = True
except ImportError:
USE_XFORMERS_OPS = False
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
# Since interpolation is applied, the image size need not be divisible # Since interpolation is applied, the image size need not be divisible
...@@ -93,7 +97,7 @@ def input_processor_for_siglip( ...@@ -93,7 +97,7 @@ def input_processor_for_siglip(
llm_inputs: LLMInputs, llm_inputs: LLMInputs,
*, *,
image_token_id: int, image_token_id: int,
image_feature_size_override: Optional[int] = None, image_feature_size_override: Optional[Union[int, List[int]]] = None,
): ):
multi_modal_data = llm_inputs.get("multi_modal_data") multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data: if multi_modal_data is None or "image" not in multi_modal_data:
...@@ -221,9 +225,7 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -221,9 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
return embeddings return embeddings
# NOTE: Not used - kept for later when we TP the ViT class SiglipParallelAttention(nn.Module):
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):
def __init__( def __init__(
self, self,
...@@ -233,38 +235,30 @@ class SiglipTPAttention(nn.Module): ...@@ -233,38 +235,30 @@ class SiglipTPAttention(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size() self.head_dim = self.embed_dim // self.num_heads
self.total_num_heads = config.num_attention_heads if self.head_dim * self.num_heads != self.embed_dim:
if self.total_num_heads % tp_size != 0:
raise ValueError(
f"Number of attention heads ({self.total_num_heads}) "
"must be divisible by the tensor model parallel size"
f" ({tp_size}).")
self.num_heads = self.total_num_heads // tp_size
self.head_dim = self.embed_dim // self.total_num_heads
if self.head_dim * self.total_num_heads != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads (got " raise ValueError(f"embed_dim must be divisible by num_heads (got "
"`embed_dim`: {self.embed_dim} and `num_heads`:" "`embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads}).") f" {self.num_heads}).")
self.qkv_size = self.num_heads * self.head_dim
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim, hidden_size=self.embed_dim,
head_size=self.head_dim, head_size=self.head_dim,
total_num_heads=self.total_num_heads, total_num_heads=self.num_heads,
quant_config=quant_config, quant_config=quant_config,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
input_size=self.embed_dim, input_size=self.embed_dim,
output_size=self.embed_dim, output_size=self.embed_dim,
quant_config=quant_config, quant_config=quant_config,
) )
self.attn_fn = self._basic_attention_forward self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
def forward( def forward(
self, self,
...@@ -274,161 +268,27 @@ class SiglipTPAttention(nn.Module): ...@@ -274,161 +268,27 @@ class SiglipTPAttention(nn.Module):
batch_size, q_len, _ = hidden_states.size() batch_size, q_len, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states) qkv_states, _ = self.qkv_proj(hidden_states)
query_states, key_states, value_states = qkv_states.split( query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
[self.qkv_size] * 3, dim=-1)
attn_output = self.attn_fn(
q=query_states,
k=key_states,
v=value_states,
batch_size=batch_size,
q_len=q_len,
)
attn_output, _ = self.out_proj(attn_output)
return attn_output
def _basic_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
k_v_seq_len = k.shape[-2]
attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
if attn_weights.size() != (
batch_size,
self.num_heads,
q_len,
k_v_seq_len,
):
raise ValueError(
"Attention weights should be of size "
f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}")
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(q.dtype)
attn_weights = nn.functional.dropout(attn_weights,
p=self.dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, v)
if attn_output.size() != (
batch_size,
self.num_heads,
q_len,
self.head_dim,
):
raise ValueError(
"`attn_output` should be of size "
f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._flash_attention_forward
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
**kwargs):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the
query, key, and value. (B, S, H, D)
"""
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = flash_attn_func(
q,
k,
v,
dropout_p=self.dropout,
causal=False,
)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
self.attn_fn = self._sdpa_attention_forward
def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): query_states = query_states.view(batch_size, q_len,
q = q.view(batch_size, q_len, self.num_heads, self.num_heads_per_partition,
self.head_dim).transpose(1, 2) self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, key_states = key_states.view(batch_size, q_len,
self.head_dim).transpose(1, 2) self.num_heads_per_partition,
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, q_len,
self.num_heads_per_partition,
self.head_dim)
attn_output = torch.nn.functional.scaled_dot_product_attention( out = xops.memory_efficient_attention_forward(query_states,
q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) key_states,
value_states,
p=self.dropout,
scale=self.scale)
out = out.view(batch_size, q_len, -1)
attn_output, _ = self.out_proj(out)
attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, None
attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.attn_fn = self._xformers_attention_forward
def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
v = v.view(batch_size, q_len, self.num_heads, self.head_dim)
attn_output = memory_efficient_attention(q,
k,
v,
p=0.0,
scale=self.scale)
attn_output = attn_output.reshape(batch_size, q_len,
self.embed_dim).contiguous()
return attn_output
# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipTPAttention,
"flash_attention_2": SiglipFlashAttention2,
"sdpa": SiglipSdpaAttention,
"xformers": SiglipxFormersAttention,
}
class SiglipMLP(nn.Module): class SiglipMLP(nn.Module):
...@@ -473,8 +333,14 @@ class SiglipEncoderLayer(nn.Module): ...@@ -473,8 +333,14 @@ class SiglipEncoderLayer(nn.Module):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# TODO(ChristopherCho): use TP'ed Attention block num_heads = config.num_attention_heads
self.self_attn = SiglipAttention(config) tp_size = get_tensor_model_parallel_world_size()
if USE_XFORMERS_OPS and num_heads % tp_size == 0:
self.self_attn = SiglipParallelAttention(config,
quant_config=quant_config)
else:
self.self_attn = SiglipSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.mlp = SiglipMLP( self.mlp = SiglipMLP(
...@@ -577,14 +443,27 @@ class SiglipVisionTransformer(nn.Module): ...@@ -577,14 +443,27 @@ class SiglipVisionTransformer(nn.Module):
self.config = config self.config = config
embed_dim = config.hidden_size embed_dim = config.hidden_size
if (num_hidden_layers_override is None
or num_hidden_layers_override == config.num_hidden_layers):
self.need_post_layernorm = True
elif num_hidden_layers_override > config.num_hidden_layers:
raise ValueError(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers")
else:
self.need_post_layernorm = False
self.embeddings = SiglipVisionEmbeddings(config) self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder( self.encoder = SiglipEncoder(
config, config,
quant_config=quant_config, quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
) )
self.post_layernorm = nn.LayerNorm(embed_dim, if self.need_post_layernorm:
eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = nn.Identity()
self.use_head = (True if not hasattr(config, "vision_use_head") else self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head) config.vision_use_head)
if self.use_head: if self.use_head:
...@@ -604,7 +483,6 @@ class SiglipVisionTransformer(nn.Module): ...@@ -604,7 +483,6 @@ class SiglipVisionTransformer(nn.Module):
encoder_outputs = self.encoder(inputs_embeds=hidden_states) encoder_outputs = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
# TODO: add this back when pooled_output is used in inference # TODO: add this back when pooled_output is used in inference
# if self.use_head: # if self.use_head:
# pooled_output = self.head(last_hidden_state) # pooled_output = self.head(last_hidden_state)
...@@ -623,12 +501,20 @@ class SiglipVisionModel(nn.Module): ...@@ -623,12 +501,20 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override: Optional[int] = None, num_hidden_layers_override: Optional[int] = None,
): ):
super().__init__() super().__init__()
num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
self.vision_model = SiglipVisionTransformer( self.vision_model = SiglipVisionTransformer(
config, config,
quant_config, quant_config,
num_hidden_layers_override=num_hidden_layers_override, num_hidden_layers_override=num_hidden_layers_override,
) )
@property
def need_post_layernorm(self):
return self.vision_model.need_post_layernorm
def get_input_embeddings(self) -> nn.Module: def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding return self.vision_model.embeddings.patch_embedding
...@@ -647,6 +533,11 @@ class SiglipVisionModel(nn.Module): ...@@ -647,6 +533,11 @@ class SiglipVisionModel(nn.Module):
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self.need_post_layernorm):
continue
# omit layers when num_hidden_layers_override is set # omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name: if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3]) layer_idx = int(name.split(".")[3])
......
...@@ -36,12 +36,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -36,12 +36,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
......
...@@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
......
...@@ -8,7 +8,6 @@ from functools import lru_cache ...@@ -8,7 +8,6 @@ from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast) TypedDict, Union, cast)
import librosa
import numpy as np import numpy as np
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -27,17 +26,18 @@ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn ...@@ -27,17 +26,18 @@ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (filter_weights, from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
...@@ -48,13 +48,14 @@ logger = init_logger(__name__) ...@@ -48,13 +48,14 @@ logger = init_logger(__name__)
class UltravoxAudioFeatureInputs(TypedDict): class UltravoxAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]] data: NestedTensors
"""Shape: `(batch_size, 80, M)""" """Shape: `(batch_size, num_audios, 80, M)"""
class UltravoxAudioEmbeddingInputs(TypedDict): class UltravoxAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
data: torch.Tensor data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
...@@ -85,27 +86,41 @@ def dummy_data_for_ultravox( ...@@ -85,27 +86,41 @@ def dummy_data_for_ultravox(
audio_count = mm_counts["audio"] audio_count = mm_counts["audio"]
audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [ audio_placeholder = array(
_AUDIO_PLACEHOLDER_TOKEN VLLM_TOKEN_ID_ARRAY_TYPE,
]) * get_ultravox_max_audio_tokens(ctx) * audio_count [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx)
# Add a separator between each chunk.
audio_token_ids = (audio_placeholder +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count
other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - len(audio_token_ids)) [0]) * (seq_len - len(audio_token_ids))
audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1)
mm_dict = { mm_dict = {"audio": [audio_and_sr] * audio_count}
"audio":
audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count
}
return (SequenceData(audio_token_ids + other_token_ids), mm_dict) return (SequenceData(audio_token_ids + other_token_ids), mm_dict)
def input_mapper_for_ultravox(ctx: InputContext, data: object): def input_mapper_for_ultravox(ctx: InputContext, data: object):
if isinstance(data, tuple): if not isinstance(data, list):
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) data = [data]
audio_features = []
for audio_input in data:
if not isinstance(audio_input, tuple):
raise NotImplementedError(
f"Unsupported data type: {type(audio_input)}")
(audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input)
feature_extractor = whisper_feature_extractor(ctx) feature_extractor = whisper_feature_extractor(ctx)
if sr != feature_extractor.sampling_rate: if sr != feature_extractor.sampling_rate:
try:
import librosa
except ImportError:
raise ImportError(
"Please install vllm[audio] for audio support.") from None
audio = librosa.resample(audio, audio = librosa.resample(audio,
orig_sr=sr, orig_sr=sr,
target_sr=feature_extractor.sampling_rate) target_sr=feature_extractor.sampling_rate)
...@@ -116,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): ...@@ -116,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
# Not enough audio; pad it. # Not enough audio; pad it.
audio = np.pad(audio, (0, minimum_audio_length - len(audio))) audio = np.pad(audio, (0, minimum_audio_length - len(audio)))
return MultiModalInputs({ single_audio_features = feature_extractor(
"audio_features": audio, sampling_rate=sr, padding="longest",
feature_extractor(audio, return_tensors="pt")["input_features"]
sampling_rate=sr,
padding="longest",
return_tensors="pt")["input_features"]
})
raise NotImplementedError(f"Unsupported data type: {type(data)}") # Remove the batch dimension because we're wrapping it in a list.
audio_features.append(single_audio_features.squeeze(0))
return MultiModalInputs({"audio_features": audio_features})
def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
...@@ -133,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -133,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
return llm_inputs return llm_inputs
feature_extractor = whisper_feature_extractor(ctx) feature_extractor = whisper_feature_extractor(ctx)
audio_data, sample_rate = multi_modal_data["audio"] audios = multi_modal_data["audio"]
if not isinstance(audios, list):
audio_length = audio_data.shape[0] audios = [audios]
if sample_rate != feature_extractor.sampling_rate:
# Account for resampling. audio_token_counts = []
adjustment = feature_extractor.sampling_rate / sample_rate for audio_data, sample_rate in audios:
audio_length = math.ceil(adjustment * audio_length) audio_length = audio_data.shape[0]
if sample_rate != feature_extractor.sampling_rate:
feature_extractor_output_length = math.ceil( # Account for resampling.
(audio_length - adjustment = feature_extractor.sampling_rate / sample_rate
(feature_extractor.hop_length - 1)) / feature_extractor.hop_length) audio_length = math.ceil(adjustment * audio_length)
uv_config = ctx.get_hf_config(UltravoxConfig) feature_extractor_output_length = math.ceil(
audio_num_tokens = min( (audio_length - (feature_extractor.hop_length - 1)) /
max( feature_extractor.hop_length)
1,
math.ceil(feature_extractor_output_length / uv_config = ctx.get_hf_config(UltravoxConfig)
(uv_config.stack_factor * 2))), audio_num_tokens = min(
get_ultravox_max_audio_tokens(ctx)) max(
1,
math.ceil(feature_extractor_output_length /
(uv_config.stack_factor * 2))),
get_ultravox_max_audio_tokens(ctx))
audio_token_counts.append(audio_num_tokens)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
...@@ -159,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -159,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs.get("prompt"), llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"], llm_inputs["prompt_token_ids"],
placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
repeat_count=audio_num_tokens, repeat_count=audio_token_counts,
) )
# NOTE: Create a defensive copy of the original inputs # NOTE: Create a defensive copy of the original inputs
...@@ -337,7 +357,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -337,7 +357,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
data=audio_features) data=audio_features)
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, torch.Tensor): if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. " raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}") f"Got type: {type(audio_embeds)}")
...@@ -347,22 +367,38 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -347,22 +367,38 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _process_audio_input( def _process_audio_input(
self, audio_input: UltravoxAudioInputs self, audio_input: UltravoxAudioInputs) -> NestedTensors:
) -> Union[torch.Tensor, List[torch.Tensor]]:
if audio_input["type"] == "audio_embeds": if audio_input["type"] == "audio_embeds":
return audio_input["data"] return audio_input["data"]
audio_features = audio_input["data"] audio_features = audio_input["data"]
if isinstance(audio_features, list): if isinstance(audio_features, torch.Tensor):
# TODO: Batch these through the encoder/projector instead of # Combine the B and N dimensions for the encoder/projector
# serializing them. flattened = flatten_bn(audio_features)
return [ flattened_embeddings = self._audio_features_to_embeddings(
self._audio_features_to_embeddings( flattened)
features.unsqueeze(0)).squeeze(0)
for features in audio_features # Restore the original dimensions
] embeddings = flattened_embeddings.unflatten(
else: 0, audio_features.shape[:2])
return self._audio_features_to_embeddings(audio_features) return embeddings
result = []
# TODO: Batch heterogeneous tensors through the encoder/projector
for audio_features_item in audio_features:
if isinstance(audio_features_item, torch.Tensor):
result.append(
self._audio_features_to_embeddings(audio_features_item))
else:
embeddings = [
# Add a batch dimension to embed it, then remove it.
self._audio_features_to_embeddings(tensor.unsqueeze(0)
).squeeze(0)
for tensor in audio_features_item
]
result.append(embeddings)
return result
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
...@@ -379,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -379,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
with the `input_ids`. with the `input_ids`.
Args: Args:
input_features: A batch of audio inputs, [1, 80, M]. audio_features: A batch of audio inputs [B, N, 80, M].
""" """
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None: if audio_input is not None:
......
from typing import Dict, Iterable, List, Optional, Protocol, Tuple from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, ...@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal import BatchedTensors from vllm.multimodal.base import NestedTensors
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -54,9 +55,73 @@ def init_vllm_registered_model( ...@@ -54,9 +55,73 @@ def init_vllm_registered_model(
) )
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
...
@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
...
@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
concat: Literal[True],
) -> torch.Tensor:
...
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
"""
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
The input tensor should have shape ``(B, N, ...)```.
"""
if isinstance(x, torch.Tensor):
return x.flatten(0, 1)
if concat:
return torch.cat(x)
return [x_n for x_b in x for x_n in x_b]
def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if isinstance(embeddings, torch.Tensor):
# Flatten all but the last dimension.
return embeddings.flatten(0, -2)
return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
def _embedding_count_expression(embeddings: NestedTensors) -> str:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if isinstance(embeddings, torch.Tensor):
return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
return " + ".join(
_embedding_count_expression(inner) for inner in embeddings)
def merge_multimodal_embeddings(input_ids: torch.Tensor, def merge_multimodal_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: BatchedTensors, multimodal_embeddings: NestedTensors,
placeholder_token_id: int) -> torch.Tensor: placeholder_token_id: int) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
...@@ -67,30 +132,17 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, ...@@ -67,30 +132,17 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
This updates ``inputs_embeds`` in place. This updates ``inputs_embeds`` in place.
""" """
mask = (input_ids == placeholder_token_id) mask = (input_ids == placeholder_token_id)
num_expected_tokens = mask.sum() num_expected_tokens = mask.sum().item()
assert isinstance(num_expected_tokens, int)
if isinstance(multimodal_embeddings, torch.Tensor):
batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape flattened = _flatten_embeddings(multimodal_embeddings)
total_tokens = batch_size * batch_tokens if flattened.shape[0] != num_expected_tokens:
if num_expected_tokens != total_tokens: expr = _embedding_count_expression(multimodal_embeddings)
expr = f"{batch_size} x {batch_tokens}" raise ValueError(
raise ValueError( f"Attempted to assign {expr} = {flattened.shape[0]} "
f"Attempted to assign {expr} = {total_tokens} " f"multimodal tokens to {num_expected_tokens} placeholders")
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = multimodal_embeddings.view(
total_tokens, embed_dim)
else:
size_per_batch = [t.shape[0] for t in multimodal_embeddings]
total_tokens = sum(size_per_batch)
if num_expected_tokens != total_tokens:
expr = ' + '.join(map(str, size_per_batch))
raise ValueError(
f"Attempted to assign {expr} = {total_tokens} "
f"multimodal tokens to {num_expected_tokens} placeholders")
inputs_embeds[mask] = torch.cat(multimodal_embeddings)
inputs_embeds[mask] = flattened
return inputs_embeds return inputs_embeds
......
...@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
......
from fractions import Fraction
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import torch import torch
...@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter): ...@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
""" """
def __init__(self, def __init__(self,
packed_factor: int, packed_factor: Union[int, Fraction],
packed_dim: int, packed_dim: int,
marlin_tile_size: Optional[int] = None, marlin_tile_size: Optional[int] = None,
**kwargs): **kwargs):
...@@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter): ...@@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter):
""" """
def __init__(self, def __init__(self,
packed_factor: int, packed_factor: Union[int, Fraction],
packed_dim: int, packed_dim: int,
marlin_tile_size: Optional[int] = None, marlin_tile_size: Optional[int] = None,
**kwargs): **kwargs):
......
from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins, from .base import (BatchedTensorInputs, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin, MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors) NestedTensors)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
...@@ -14,7 +14,6 @@ See also: ...@@ -14,7 +14,6 @@ See also:
__all__ = [ __all__ = [
"BatchedTensorInputs", "BatchedTensorInputs",
"BatchedTensors",
"MultiModalDataBuiltins", "MultiModalDataBuiltins",
"MultiModalDataDict", "MultiModalDataDict",
"MultiModalInputs", "MultiModalInputs",
......
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from typing import Callable, Dict, List, Mapping, Optional from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
from typing import Sequence as GenericSequence TypedDict, TypeVar, Union, cast, final)
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final
import numpy as np import numpy as np
import torch import torch
...@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias ...@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import InputContext from vllm.inputs import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import JSONTree, json_map_leaves from vllm.utils import JSONTree, is_list_of, json_map_leaves
logger = init_logger(__name__) logger = init_logger(__name__)
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor] NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
""" """
Use a list instead of a tensor if the dimensions of each element do not match. Uses a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
""" """
BatchedTensors: TypeAlias = JSONTree[torch.Tensor] BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
""" """
A dictionary containing nested tensors which have been batched via A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`. :meth:`MultiModalInputs.batch`.
...@@ -54,26 +46,24 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -54,26 +46,24 @@ class MultiModalInputs(_MultiModalInputsBase):
""" """
@staticmethod @staticmethod
def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors: def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
""" """
If each input tensor in the batch has the same shape, return a single Recursively stacks lists of tensors when they all have the same shape.
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
""" """
# may be list rather than tensors if isinstance(nested_tensors, torch.Tensor):
if isinstance(tensors[0], list): return nested_tensors
return [[t for t in tensor[0]]
for tensor in cast(List[List[torch.Tensor]], tensors)]
tensors_ = cast(List[torch.Tensor], tensors) stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
unbatched_shape = tensors_[0].shape[1:] tensors_ = cast(List[torch.Tensor], stacked)
if any(t.shape != tensors_[0].shape for t in tensors_):
# The tensors have incompatible shapes and can't be stacked.
return tensors_
for tensor in tensors_: return torch.stack(tensors_)
if tensor.shape[1:] != unbatched_shape:
return [tensor.squeeze(0) for tensor in tensors_]
return torch.cat(tensors_, dim=0)
@staticmethod @staticmethod
def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
...@@ -102,7 +92,7 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -102,7 +92,7 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists[k].append(v) item_lists[k].append(v)
return { return {
k: MultiModalInputs._try_concat(item_list) k: MultiModalInputs._try_stack(item_list)
for k, item_list in item_lists.items() for k, item_list in item_lists.items()
} }
...@@ -112,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase): ...@@ -112,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
*, *,
device: torch.types.Device, device: torch.types.Device,
) -> BatchedTensorInputs: ) -> BatchedTensorInputs:
return json_map_leaves(lambda x: x.to(device, non_blocking=True), json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
batched_inputs)
json_mapped = json_map_leaves(
lambda x: x.to(device, non_blocking=True),
json_inputs,
)
return cast(BatchedTensorInputs, json_mapped)
_T = TypeVar("_T") _T = TypeVar("_T")
......
import base64 import base64
from functools import lru_cache from functools import lru_cache
from io import BytesIO from io import BytesIO
from typing import List, Optional, Tuple, TypeVar, Union from typing import Any, List, Optional, Tuple, TypeVar, Union
import librosa
import numpy as np import numpy as np
import soundfile
from PIL import Image from PIL import Image
from vllm.connections import global_http_connection from vllm.connections import global_http_connection
...@@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str, ...@@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str,
return image.convert(image_mode) return image.convert(image_mode)
def try_import_audio_packages() -> Tuple[Any, Any]:
try:
import librosa
import soundfile
except ImportError:
raise ImportError(
"Please install vllm[audio] for audio support.") from None
return librosa, soundfile
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
""" """
Load audio from a URL. Load audio from a URL.
""" """
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"): if audio_url.startswith("http"):
audio_bytes = global_http_connection.get_bytes( audio_bytes = global_http_connection.get_bytes(
audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
...@@ -95,6 +105,8 @@ async def async_fetch_audio( ...@@ -95,6 +105,8 @@ async def async_fetch_audio(
""" """
Asynchronously fetch audio from a URL. Asynchronously fetch audio from a URL.
""" """
librosa, _ = try_import_audio_packages()
if audio_url.startswith("http"): if audio_url.startswith("http"):
audio_bytes = await global_http_connection.async_get_bytes( audio_bytes = await global_http_connection.async_get_bytes(
audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT)
...@@ -108,6 +120,16 @@ async def async_fetch_audio( ...@@ -108,6 +120,16 @@ async def async_fetch_audio(
return librosa.load(BytesIO(audio_bytes), sr=None) return librosa.load(BytesIO(audio_bytes), sr=None)
def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
audio, sr = fetch_audio(audio_url)
return {"audio": (audio, sr)}
def get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = fetch_image(image_url)
return {"image": image}
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
audio, sr = await async_fetch_audio(audio_url) audio, sr = await async_fetch_audio(audio_url)
return {"audio": (audio, sr)} return {"audio": (audio, sr)}
...@@ -123,6 +145,8 @@ def encode_audio_base64( ...@@ -123,6 +145,8 @@ def encode_audio_base64(
sampling_rate: int, sampling_rate: int,
) -> str: ) -> str:
"""Encode audio as base64.""" """Encode audio as base64."""
_, soundfile = try_import_audio_packages()
buffered = BytesIO() buffered = BytesIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV") soundfile.write(buffered, audio, sampling_rate, format="WAV")
...@@ -189,10 +213,13 @@ def repeat_and_pad_placeholder_tokens( ...@@ -189,10 +213,13 @@ def repeat_and_pad_placeholder_tokens(
prompt_token_ids: List[int], prompt_token_ids: List[int],
*, *,
placeholder_token_id: int, placeholder_token_id: int,
repeat_count: int = 1, repeat_count: Union[int, List[int]],
pad_token_left: Optional[int] = None, pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None, pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int]]: ) -> Tuple[Optional[str], List[int]]:
if isinstance(repeat_count, int):
repeat_count = [repeat_count]
if prompt is None: if prompt is None:
new_prompt = None new_prompt = None
else: else:
...@@ -201,13 +228,6 @@ def repeat_and_pad_placeholder_tokens( ...@@ -201,13 +228,6 @@ def repeat_and_pad_placeholder_tokens(
tokenizer.decode(pad_token_left)) tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right)) tokenizer.decode(pad_token_right))
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
placeholder_token_count = prompt.count(placeholder_token_str) placeholder_token_count = prompt.count(placeholder_token_str)
# This is an arbitrary number to distinguish between the two cases # This is an arbitrary number to distinguish between the two cases
...@@ -216,28 +236,45 @@ def repeat_and_pad_placeholder_tokens( ...@@ -216,28 +236,45 @@ def repeat_and_pad_placeholder_tokens(
"Please follow the prompt format that is " "Please follow the prompt format that is "
"documented on HuggingFace which does not involve " "documented on HuggingFace which does not involve "
"repeating %s tokens.", placeholder_token_str) "repeating %s tokens.", placeholder_token_str)
elif placeholder_token_count > 1: if placeholder_token_count < len(repeat_count):
logger.warning("Multiple multi-modal input is not supported yet, " logger.warning(
"so any extra placeholder tokens will be treated " "The number of multi-modal placeholder tokens in the prompt "
"as plain text.") "is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text")
# The image tokens are removed to be consistent with HuggingFace repeat_count = repeat_count[:placeholder_token_count]
new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
prompt_parts = prompt.split(placeholder_token_str,
maxsplit=len(repeat_count))
new_prompt = ""
for i, repeat_count_item in enumerate(repeat_count):
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count_item,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1]
new_token_ids: List[int] = [] new_token_ids: List[int] = []
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids): for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id: if token == placeholder_token_id:
replacement_ids = repeat_and_pad_token( replacement_ids = repeat_and_pad_token(
placeholder_token_id, placeholder_token_id,
repeat_count=repeat_count, repeat_count=repeat_count[placeholder_token_idx],
pad_token_left=pad_token_left, pad_token_left=pad_token_left,
pad_token_right=pad_token_right, pad_token_right=pad_token_right,
) )
new_token_ids.extend(replacement_ids) new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1
# No need to further scan the list since we only replace once # No need to further scan the list since we replaced all tokens
new_token_ids.extend(prompt_token_ids[i + 1:]) if placeholder_token_idx >= len(repeat_count):
break new_token_ids.extend(prompt_token_ids[i + 1:])
break
else: else:
new_token_ids.append(token) new_token_ids.append(token)
......
...@@ -21,7 +21,9 @@ _R = TypeVar("_R") ...@@ -21,7 +21,9 @@ _R = TypeVar("_R")
if pynvml.__file__.endswith("__init__.py"): if pynvml.__file__.endswith("__init__.py"):
logger.warning( logger.warning(
"You are using a deprecated `pynvml` package. Please install" "You are using a deprecated `pynvml` package. Please install"
" `nvidia-ml-py` instead. See https://pypi.org/project/pynvml " " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
" When both of them are installed, `pynvml` will take precedence"
" and cause errors. See https://pypi.org/project/pynvml "
"for more information.") "for more information.")
# NVML utils # NVML utils
...@@ -82,6 +84,9 @@ except ModuleNotFoundError: ...@@ -82,6 +84,9 @@ except ModuleNotFoundError:
def device_id_to_physical_device_id(device_id: int) -> int: def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ: if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
if device_ids == [""]:
raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string,"
" which means GPU support is disabled.")
physical_device_id = device_ids[device_id] physical_device_id = device_ids[device_id]
return int(physical_device_id) return int(physical_device_id)
else: else:
......
import os
from functools import lru_cache from functools import lru_cache
from typing import Tuple from typing import Tuple
import torch import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum from .interface import Platform, PlatformEnum
logger = init_logger(__name__)
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
logger.warning("`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
" `spawn` instead.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
class RocmPlatform(Platform): class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM _enum = PlatformEnum.ROCM
......
...@@ -125,6 +125,15 @@ def main(): ...@@ -125,6 +125,15 @@ def main():
serve_parser.add_argument("model_tag", serve_parser.add_argument("model_tag",
type=str, type=str,
help="The model tag to serve") help="The model tag to serve")
serve_parser.add_argument(
"--config",
type=str,
default='',
required=False,
help="Read CLI options from a config file."
"Must be a YAML with the following options:"
"https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server"
)
serve_parser = make_arg_parser(serve_parser) serve_parser = make_arg_parser(serve_parser)
serve_parser.set_defaults(dispatch_function=serve) serve_parser.set_defaults(dispatch_function=serve)
......
...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod ...@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from array import array from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping,
Tuple, Union, cast) Optional, Set, Tuple, Union, cast)
import msgspec import msgspec
import torch import torch
...@@ -474,11 +474,8 @@ class Sequence: ...@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute() self.data.reset_state_for_recompute()
def append_token_id( def append_token_id(self, token_id: int, logprobs: Dict[int,
self, Logprob]) -> None:
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
assert token_id in logprobs assert token_id in logprobs
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob)
...@@ -814,6 +811,9 @@ class SequenceGroup: ...@@ -814,6 +811,9 @@ class SequenceGroup:
self.is_single_seq = len(self.seqs) == 1 self.is_single_seq = len(self.seqs) == 1
def is_finished(self) -> bool: def is_finished(self) -> bool:
if self.is_single_seq:
return self.seqs[0].is_finished()
return all(seq.is_finished() for seq in self.seqs) return all(seq.is_finished() for seq in self.seqs)
def is_prefill(self) -> bool: def is_prefill(self) -> bool:
...@@ -886,7 +886,7 @@ class SequenceGroupMetadata( ...@@ -886,7 +886,7 @@ class SequenceGroupMetadata(
request_id: str request_id: str
is_prompt: bool is_prompt: bool
seq_data: Dict[int, SequenceData] seq_data: Dict[int, SequenceData]
sampling_params: SamplingParams sampling_params: Optional[SamplingParams]
block_tables: Dict[int, List[int]] block_tables: Dict[int, List[int]]
do_sample: bool = True do_sample: bool = True
pooling_params: Optional[PoolingParams] = None pooling_params: Optional[PoolingParams] = None
...@@ -1060,76 +1060,6 @@ class IntermediateTensors( ...@@ -1060,76 +1060,6 @@ class IntermediateTensors(
return f"IntermediateTensors(tensors={self.tensors})" return f"IntermediateTensors(tensors={self.tensors})"
class SamplerOutput(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional[torch.Tensor] = None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states: Optional[torch.Tensor] = None
# Time taken in the forward pass for this across all workers
model_forward_time: Optional[float] = None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
def __setitem__(self, idx: int, value):
self.outputs[idx] = value
def __len__(self):
return len(self.outputs)
def __eq__(self, other: object):
return isinstance(other,
self.__class__) and self.outputs == other.outputs
def __repr__(self) -> str:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr = ("None" if self.sampled_token_probs is None
else self.sampled_token_probs.shape)
sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else
self.sampled_token_ids.shape)
return (
f"SamplerOutput(outputs={self.outputs}, "
f"sampled_token_probs={sampled_token_probs_repr}, "
f"sampled_token_ids={sampled_token_ids_repr}, "
f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
class PoolerOutput( class PoolerOutput(
msgspec.Struct, msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg] omit_defaults=True, # type: ignore[call-arg]
...@@ -1293,6 +1223,8 @@ class ExecuteModelRequest( ...@@ -1293,6 +1223,8 @@ class ExecuteModelRequest(
finished_requests_ids: List[str] = msgspec.field(default_factory=list) finished_requests_ids: List[str] = msgspec.field(default_factory=list)
# The last sampled token ids for multi step decoding. # The last sampled token ids for multi step decoding.
last_sampled_token_ids: Optional[torch.Tensor] = None last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
@property @property
def is_first_multi_step(self) -> bool: def is_first_multi_step(self) -> bool:
...@@ -1338,4 +1270,5 @@ class ExecuteModelRequest( ...@@ -1338,4 +1270,5 @@ class ExecuteModelRequest(
num_steps=self.num_steps, num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids, finished_requests_ids=self.finished_requests_ids,
last_sampled_token_ids=self.last_sampled_token_ids.clone() last_sampled_token_ids=self.last_sampled_token_ids.clone()
if self.last_sampled_token_ids is not None else None) if self.last_sampled_token_ids is not None else None,
async_callback=self.async_callback)
...@@ -5,13 +5,13 @@ from typing import Iterator, List, Optional, Tuple ...@@ -5,13 +5,13 @@ from typing import Iterator, List, Optional, Tuple
import torch import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest,
SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceData, SequenceGroupMetadata,
get_all_seq_ids) get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch, from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
split_batch_by_proposal_len)
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
SeqId = int SeqId = int
...@@ -88,17 +88,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -88,17 +88,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert len(target_sampler_output) == 1, "expected single-step output" assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0] target_sampler_output = target_sampler_output[0]
(all_tokens, all_probs, spec_logprobs, if not non_spec_indices:
all_hidden_states) = self._contract_batch( # All sequence groups in batch have spec decoding enabled
contracted_bs=len(execute_model_req.seq_group_metadata_list), contracted = self._contract_batch_all_spec(
target_sampler_output=target_sampler_output, target_sampler_output=target_sampler_output,
proposals=proposals, proposals=proposals,
num_scoring_tokens=num_scoring_tokens, )
non_spec_indices=non_spec_indices, else:
spec_indices=spec_indices, # Batch has a mix of spec decode enabled and disabled seq groups
k=execute_model_req.num_lookahead_slots, contracted = self._contract_batch(
) contracted_bs=len(execute_model_req.seq_group_metadata_list),
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
non_spec_indices=non_spec_indices,
spec_indices=spec_indices,
k=execute_model_req.num_lookahead_slots,
)
all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
return SpeculativeScores( return SpeculativeScores(
probs=all_probs, probs=all_probs,
token_ids=all_tokens, token_ids=all_tokens,
...@@ -121,14 +129,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -121,14 +129,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec # proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be # and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens. # done by supporting per-sequence proposal lens.
spec_seqs, spec_indices = split_batch_by_proposal_len( (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
seq_group_metadata_list, split_batch_by_proposal_len(
proposal_lens_list, seq_group_metadata_list, proposal_lens_list)
select_proposal_len_zero=False)
non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
seq_group_metadata_list,
proposal_lens_list,
select_proposal_len_zero=True)
target_seq_group_metadata_list = self._create_scoring_model_input( target_seq_group_metadata_list = self._create_scoring_model_input(
seq_group_metadata_list=spec_seqs, seq_group_metadata_list=spec_seqs,
...@@ -171,7 +174,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -171,7 +174,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is # The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for # equal to the total expanded batch size minus the number of samples for
# non-speculative sequences. # non-speculative sequences.
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape non_spec_expanded_bs = len(non_spec_target_token_ids)
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
...@@ -181,7 +184,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -181,7 +184,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if target_hidden_states is not None: if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape( target_hidden_states = target_hidden_states.reshape(
spec_expanded_bs, k + 1, target_hidden_states.shape[-1]) *target_token_ids.shape, target_hidden_states.shape[-1])
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1) fill_value=-1)
...@@ -196,24 +199,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -196,24 +199,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states = None all_hidden_states = None
if non_spec_indices: if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_tokens[non_spec_indices, :1] = \
all_probs[non_spec_indices, :1, :] = non_spec_target_probs non_spec_target_token_ids.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs all_probs[non_spec_indices, :1, :] = \
non_spec_target_probs.unsqueeze(1)
all_logprobs[non_spec_indices, :1, :] = \
non_spec_target_logprobs.unsqueeze(1)
if all_hidden_states is not None: if all_hidden_states is not None:
all_hidden_states[ assert non_spec_target_hidden_states is not None
non_spec_indices, :1, :] = non_spec_target_hidden_states all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states.unsqueeze(1)
if spec_indices: if spec_indices:
all_tokens[spec_indices] = target_token_ids all_tokens[spec_indices] = target_token_ids
all_probs[spec_indices] = target_probs all_probs[spec_indices] = target_probs
all_logprobs[spec_indices] = target_logprobs all_logprobs[spec_indices] = target_logprobs
if all_hidden_states is not None: if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states all_hidden_states[spec_indices] = target_hidden_states
return all_tokens, all_probs, all_logprobs, all_hidden_states return all_tokens, all_probs, all_logprobs, all_hidden_states
def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape
# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)
target_probs = target_sampler_output.sampled_token_probs.reshape(
*target_token_ids.shape, self._vocab_size)
target_logprobs = target_sampler_output.logprobs.reshape(
target_probs.shape)
target_hidden_states = target_sampler_output.hidden_states
if target_hidden_states is not None:
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states)
def _create_scoring_model_input( def _create_scoring_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
...@@ -345,8 +382,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -345,8 +382,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size=1, token_chunk_size=1,
) )
@staticmethod
def _split_scoring_output( def _split_scoring_output(
self, sampler_output: SamplerOutput, num_scoring_tokens: int sampler_output: SamplerOutput, num_scoring_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor], torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]: torch.Tensor, Optional[torch.Tensor]]:
...@@ -361,10 +399,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -361,10 +399,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# #
# First samples are from speculative scoring, latter samples are non- # First samples are from speculative scoring, latter samples are non-
# speculative samples. # speculative samples.
split_sizes = [ split_sizes = (num_scoring_tokens,
num_scoring_tokens, sampler_output.sampled_token_ids.numel() -
sampler_output.sampled_token_ids.numel() - num_scoring_tokens num_scoring_tokens)
]
(spec_probs, non_spec_probs (spec_probs, non_spec_probs
) = sampler_output.sampled_token_probs.split(split_sizes) ) = sampler_output.sampled_token_probs.split(split_sizes)
(spec_sampled_tokens, non_spec_sampled_tokens (spec_sampled_tokens, non_spec_sampled_tokens
...@@ -382,32 +419,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -382,32 +419,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else: else:
spec_hidden_states, non_spec_hidden_states = None, None spec_hidden_states, non_spec_hidden_states = None, None
# Convert scores to tensors. return (spec_sampled_tokens, spec_probs, spec_logprobs,
sampler_output.sampled_token_probs = spec_probs spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
sampler_output.sampled_token_ids = spec_sampled_tokens non_spec_logprobs, non_spec_hidden_states)
sampler_output.logprobs = spec_logprobs
sampler_output.hidden_states = spec_hidden_states
(target_token_ids, target_probs, target_logprobs,
target_hidden_states) = sampler_output_to_torch([sampler_output],
True)
# Convert non-speculative output tokens to tensors.
sampler_output.sampled_token_probs = non_spec_probs
sampler_output.sampled_token_ids = non_spec_sampled_tokens
sampler_output.logprobs = non_spec_logprobs
sampler_output.hidden_states = non_spec_hidden_states
(non_spec_target_token_ids, non_spec_target_probs,
non_spec_target_logprobs,
non_spec_target_hidden_states) = sampler_output_to_torch(
[sampler_output], True)
return (target_token_ids, target_probs, target_logprobs,
target_hidden_states, non_spec_target_token_ids,
non_spec_target_probs, non_spec_target_logprobs,
non_spec_target_hidden_states)
@staticmethod
def _create_target_seq_id_iterator( def _create_target_seq_id_iterator(
self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
"""Create an iterator for creating target sequence ids. """Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored. distinct target sequence id for each proposal token to be scored.
...@@ -417,8 +435,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -417,8 +435,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
""" """
return count(start=max(seq_ids) + 1) return count(start=max(seq_ids) + 1)
@staticmethod
def _get_token_ids_to_score( def _get_token_ids_to_score(
self,
full_spec_token_ids: List[TokenId] # shape: [k] full_spec_token_ids: List[TokenId] # shape: [k]
) -> List[List[TokenId]]: ) -> List[List[TokenId]]:
"""Given an int tensor of proposal token ids, return a list of """Given an int tensor of proposal token ids, return a list of
...@@ -439,8 +457,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ...@@ -439,8 +457,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids: List[TokenId] = [] empty_token_ids: List[TokenId] = []
token_ids_to_score = [empty_token_ids] token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([ token_ids_to_score.extend(full_spec_token_ids[:i + 1]
full_spec_token_ids[:i + 1] for i in range(len(full_spec_token_ids)))
for i in range(len(full_spec_token_ids))
])
return token_ids_to_score return token_ids_to_score
...@@ -3,6 +3,7 @@ from typing import List, Optional ...@@ -3,6 +3,7 @@ from typing import List, Optional
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.sampler import SamplerOutput
try: try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.attention.backends.flash_attn import FlashAttentionMetadata
...@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig) PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, from vllm.sequence import ExecuteModelRequest, IntermediateTensors
SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner) ModelRunner)
......
...@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple ...@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
import torch import torch
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.model_executor.layers.sampler import SamplerOutput
SequenceGroupMetadata) from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
......
...@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple ...@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
import torch import torch
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, from vllm.model_executor.layers.sampler import SamplerOutput
SequenceGroupMetadata) from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata
from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment