Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
...@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors ...@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from . import mixtral from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers, maybe_prefix from .utils import AutoWeightsLoader, make_layers, maybe_prefix
class GraniteMoeMoE(nn.Module): class GraniteMoeMoE(nn.Module):
...@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module): ...@@ -252,6 +252,8 @@ class GraniteMoeModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config # Required by MixtralModel
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab self.vocab_size = config.vocab_size + lora_vocab
...@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module): ...@@ -304,6 +306,40 @@ class GraniteMoeModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
new_weights = {}
for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'):
for e in range(p.size(0)):
w1_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w1.weight")
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w3.weight")
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
f".block_sparse_moe.experts.{e}.w2.weight")
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
else:
new_weights[n] = p
return mixtral.MixtralModel.load_weights(self, new_weights.items())
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -331,7 +367,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config # Required by MixtralForCausalLM
self.model = GraniteMoeModel(vllm_config=vllm_config, self.model = GraniteMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
...@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -403,37 +438,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
new_weights = {} loader = AutoWeightsLoader(
for n, p in weights: self,
if n.endswith('.block_sparse_moe.input_linear.weight'): skip_prefixes=(["lm_head."]
for e in range(p.size(0)): if self.config.tie_word_embeddings else None),
w1_name = n.replace( )
'.block_sparse_moe.input_linear.weight', return loader.load_weights(weights)
f".block_sparse_moe.experts.{e}.w1.weight")
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w3.weight")
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
f".block_sparse_moe.experts.{e}.w2.weight")
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
pass
else:
new_weights[n] = p
return mixtral.MixtralForCausalLM.load_weights(self,
new_weights.items())
...@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors ...@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors
from . import mixtral from . import mixtral
from .granitemoe import GraniteMoeAttention, GraniteMoeMoE from .granitemoe import GraniteMoeAttention, GraniteMoeMoE
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers, maybe_prefix from .utils import AutoWeightsLoader, make_layers, maybe_prefix
class GraniteMoeSharedMLP(nn.Module): class GraniteMoeSharedMLP(nn.Module):
...@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -152,6 +152,8 @@ class GraniteMoeSharedModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config # Required by MixtralModel
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -207,6 +209,40 @@ class GraniteMoeSharedModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
new_weights = {}
for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'):
for e in range(p.size(0)):
w1_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w1.weight")
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w3.weight")
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
f".block_sparse_moe.experts.{e}.w2.weight")
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
else:
new_weights[n] = p
return mixtral.MixtralModel.load_weights(self, new_weights.items())
class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -234,7 +270,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config
self.model = GraniteMoeSharedModel(vllm_config=vllm_config, self.model = GraniteMoeSharedModel(vllm_config=vllm_config,
prefix=maybe_prefix( prefix=maybe_prefix(
...@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -307,37 +342,9 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
new_weights = {} loader = AutoWeightsLoader(
for n, p in weights: self,
if n.endswith('.block_sparse_moe.input_linear.weight'): skip_prefixes=(["lm_head."]
for e in range(p.size(0)): if self.config.tie_word_embeddings else None),
w1_name = n.replace( )
'.block_sparse_moe.input_linear.weight', return loader.load_weights(weights)
f".block_sparse_moe.experts.{e}.w1.weight")
w3_name = n.replace(
'.block_sparse_moe.input_linear.weight',
f".block_sparse_moe.experts.{e}.w3.weight")
w1_param, w3_param = p[e].chunk(2, dim=0)
assert w1_name not in new_weights
assert w3_name not in new_weights
new_weights[w1_name] = w1_param
new_weights[w3_name] = w3_param
elif n.endswith('.block_sparse_moe.output_linear.weight'):
for e in range(p.size(0)):
w2_name = n.replace(
'.block_sparse_moe.output_linear.weight',
f".block_sparse_moe.experts.{e}.w2.weight")
w2_param = p[e]
assert w2_name not in new_weights
new_weights[w2_name] = w2_param
elif n.endswith('.block_sparse_moe.router.layer.weight'):
gate_name = n.replace('.block_sparse_moe.router.layer.weight',
".block_sparse_moe.gate.weight")
assert gate_name not in new_weights
new_weights[gate_name] = p
elif n == 'lm_head.weight' and self.config.tie_word_embeddings:
pass
else:
new_weights[n] = p
return mixtral.MixtralForCausalLM.load_weights(self,
new_weights.items())
...@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -302,6 +302,8 @@ class Grok1Model(nn.Module): ...@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -370,94 +372,6 @@ class Grok1Model(nn.Module): ...@@ -370,94 +372,6 @@ class Grok1Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Grok1Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.output_multiplier_scale = getattr(
config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
self.output_multiplier_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -480,9 +394,6 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales # Loading kv cache quantization scales
...@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -553,13 +464,107 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if "norm.scale" in name: if "norm.scale" in name:
name = name.replace("scale", "weight") name = name.replace("scale", "weight")
# Skip lm_head when tie_word_embeddings is True
if "lm_head" in name and self.config.tie_word_embeddings:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Grok1Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.output_multiplier_scale = getattr(
config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
self.output_multiplier_scale)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = ["rotary_emb.inv_freq"]
# Skip lm_head when tie_word_embeddings is True
if self.config.tie_word_embeddings:
skip_prefixes.append("lm_head")
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
...@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num( def resolve_min_max_num(
self, self,
...@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
**kwargs, **kwargs,
) )
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
if mm_counts.get("image", 0) <= 1:
max_tokens_per_image = max_tokens_one_image
else:
max_tokens_per_image = self.get_max_image_tokens(use_msac=False)
return {"image": max_tokens_per_image}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
use_msac=use_msac, use_msac=use_msac,
) )
def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
use_msac=use_msac,
)
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
): ):
......
...@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -32,18 +32,18 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize from vllm.multimodal.parse import ImageProcessorItems, ImageSize
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalDataItems, MultiModalDataItems, PromptReplacement,
MultiModalFieldConfig, PromptUpdate, PromptUpdateDetails)
PromptReplacement, PromptUpdate,
encode_tokens)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
# yapf: disable # yapf: disable
...@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal ...@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict): class Idefics3ImagePixelInputs(TypedDict):
...@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict): ...@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict): class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict): ...@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
...@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -114,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _resize_output_size(self, def _resize_output_size(self,
*, *,
height: int, height: int,
...@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -223,6 +199,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
return grid_w * grid_h + 1 return grid_w * grid_h + 1
def _get_image_token(
self,
processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token.content
fake_image_token = processor.fake_image_token.content
global_image_token = processor.global_image_tag
return image_token, fake_image_token, global_image_token
def get_image_repl( def get_image_repl(
self, self,
*, *,
...@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -233,9 +219,8 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()
image_token = processor.image_token.content image_token, fake_image_token, global_img_token = self._get_image_token(
fake_image_token = processor.fake_image_token.content processor)
global_img_token = processor.global_image_tag
image_seq_len = processor.image_seq_len image_seq_len = processor.image_seq_len
grid_placeholder = "<row_{n_h}_col_{n_w}>" grid_placeholder = "<row_{n_h}_col_{n_w}>"
...@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -275,19 +260,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
processor: Optional[Idefics3Processor], processor: Optional[Idefics3Processor],
) -> int: ) -> int:
tokenizer = self.get_tokenizer() if processor is None:
image_repl = self.get_image_repl( processor = self.get_hf_processor()
num_patches = self.get_num_patches(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
) )
image_repl_tokens = encode_tokens( return num_patches * processor.image_seq_len
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -298,42 +280,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
height=image_processor.size["longest_edge"], height=image_processor.size["longest_edge"],
) )
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
): ):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token, _, _ = self.info._get_image_token(processor)
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
hf_processor = self.info.get_hf_processor() hf_processor = self.info.get_hf_processor()
image_processor: Idefics3ImageProcessor = hf_processor.image_processor image_processor: Idefics3ImageProcessor = hf_processor.image_processor
longest_edge = image_processor.max_image_size['longest_edge'] longest_edge = image_processor.max_image_size['longest_edge']
image_token = hf_processor.image_token.content
mm_data = { return {
"image": "image":
self._get_dummy_images(width=longest_edge, self._get_dummy_images(width=longest_edge,
height=longest_edge, height=longest_edge,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Idefics3MultiModalProcessor( class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]): BaseMultiModalProcessor[Idefics3ProcessingInfo]):
...@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor( ...@@ -364,28 +339,6 @@ class Idefics3MultiModalProcessor(
] ]
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [ num_patches = [
self.info.get_num_patches( self.info.get_num_patches(
image_width=size.width, image_width=size.width,
...@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor( ...@@ -415,7 +368,6 @@ class Idefics3MultiModalProcessor(
"image", num_patches), "image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor( ...@@ -425,19 +377,24 @@ class Idefics3MultiModalProcessor(
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content image_token, _, _ = self.info._get_image_token(hf_processor)
def get_replacement_idefics3(item_idx: int) -> str: def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
return self.info.get_image_repl( image_repl = self.info.get_image_repl(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor, processor=hf_processor,
) )
return PromptUpdateDetails.select_text(
image_repl,
embed_text=image_token,
)
return [ return [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
...@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -675,13 +632,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None: if pixel_values is None and image_embeds is None:
return None return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
...@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -690,7 +640,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs( return Idefics3ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds, concat=True), data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
) )
if pixel_values is not None: if pixel_values is not None:
...@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -718,7 +667,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask, pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches, num_patches=num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -748,18 +696,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
] ]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -771,7 +717,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_id, self.config.image_token_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -60,6 +60,18 @@ class SupportsMultiModal(Protocol): ...@@ -60,6 +60,18 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill # Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated # TODO(ywang96): Remove this overload once v0 is deprecated
@overload @overload
......
...@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel, ...@@ -25,21 +25,20 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC): ...@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]), torch.tensor([len(item) for item in pixel_values_lst]),
} }
tokenizer = self.tokenizer
image_token_id = self.image_token_id
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst: for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0] num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches) image_repl = self.get_image_repl(feature_size, num_patches)
feature_tokens = tokenizer.encode(image_repl.features,
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text] text = [t.replace('<image>', image_repl.full, 1) for t in text]
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
...@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor): ...@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseInternVLProcessingInfo(BaseProcessingInfo): class BaseInternVLProcessingInfo(BaseProcessingInfo):
...@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ...@@ -479,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ...@@ -501,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
image_height=image_height, image_height=image_height,
) )
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo) ...@@ -541,27 +504,27 @@ _I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]): class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
return "<image>" * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
...@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -599,7 +562,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches), "image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -831,7 +793,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]: self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None) image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None: if pixel_values_flat is None and image_embeds is None:
...@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -860,20 +821,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. " raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}") f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values_flat=self._validate_pixel_values( pixel_values_flat=self._validate_pixel_values(
pixel_values_flat), pixel_values_flat),
num_patches=image_num_patches, num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -913,21 +868,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values":
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -941,7 +891,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.img_context_token_id, self.img_context_token_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -309,7 +309,7 @@ class LlamaModel(nn.Module): ...@@ -309,7 +309,7 @@ class LlamaModel(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -637,14 +637,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -637,14 +637,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config, self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"), prefix=maybe_prefix(prefix, "model"),
layer_type=layer_type) layer_type=layer_type)
...@@ -681,12 +680,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -681,12 +680,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def _init_model(self, def _init_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
return LlamaModel(vllm_config=vllm_config, return LlamaModel(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
layer_type=layer_type) layer_type=layer_type)
...@@ -719,7 +717,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -719,7 +717,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
...@@ -731,7 +728,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -731,7 +728,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.maybe_remap_mistral(name, loaded_weight) self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights) for name, loaded_weight in weights)
# This function is used to remap the mistral format as # This function is used to remap the mistral format as
# used by Mistral and Llama <=2 # used by Mistral and Llama <=2
def maybe_remap_mistral( def maybe_remap_mistral(
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import torch import torch
from torch import nn from torch import nn
...@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -36,8 +36,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, LlamaModel from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
is_pp_missing_parameter) is_pp_missing_parameter)
...@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module): ...@@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
topk: int, topk: int,
renormalize: bool, renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1) router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to( router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype) hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32)) return (router_scores, router_indices.to(torch.int32))
...@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module): ...@@ -155,14 +155,8 @@ class Llama4Attention(nn.Module):
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads self.n_rep = self.num_heads // self.num_kv_heads
self.q_norm = RMSNorm( self.qk_norm = RMSNorm(
hidden_size=self.q_size, hidden_size=self.head_dim,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.k_norm = RMSNorm(
hidden_size=self.kv_size,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
has_weight=False, has_weight=False,
dtype=torch.float32, dtype=torch.float32,
...@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module): ...@@ -226,10 +220,11 @@ class Llama4Attention(nn.Module):
if self.rotary_emb is not None: if self.rotary_emb is not None:
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
if self.q_norm is not None: if self.qk_norm is not None:
q = self.q_norm(q.float()).to(q.dtype) q = q.reshape(-1, self.num_heads, self.head_dim)
if self.k_norm is not None: q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
k = self.k_norm(k.float()).to(k.dtype) k = k.reshape(-1, self.num_kv_heads, self.head_dim)
k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
# to NoPE layers, where the inference-time temperature tuning function # to NoPE layers, where the inference-time temperature tuning function
...@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module): ...@@ -247,7 +242,7 @@ class Llama4Attention(nn.Module):
return output return output
class Llama4DecoderLayer(LlamaDecoderLayer): class Llama4DecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer): ...@@ -256,8 +251,9 @@ class Llama4DecoderLayer(LlamaDecoderLayer):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix) self.layer_idx = extract_layer_index(prefix)
nn.Module.__init__(self)
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = config.rope_theta rope_theta = config.rope_theta
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
...@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel): ...@@ -329,7 +325,7 @@ class Llama4Model(LlamaModel):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
...@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM): ...@@ -471,20 +467,24 @@ class Llama4ForCausalLM(LlamaForCausalLM):
} }
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Update temperature tuning config from generation config # update temperature tuning config from generation config
gen_config = vllm_config.model_config.try_get_generation_config() gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config) gen_config.update(vllm_config.model_config.override_generation_config)
# enable temperature tuning by default when max_model_len > 32K
default_attn_temperature_tuning = \
vllm_config.model_config.max_model_len > 32768
vllm_config.model_config.hf_config.attn_temperature_tuning \ vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get("attn_temperature_tuning", False) = gen_config.get(
LlamaForCausalLM.__init__(self, "attn_temperature_tuning", default_attn_temperature_tuning)
vllm_config=vllm_config,
prefix=prefix, super().__init__(vllm_config=vllm_config,
layer_type=Llama4DecoderLayer) prefix=prefix,
layer_type=Llama4DecoderLayer)
def _init_model(self, def _init_model(self,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[Llama4DecoderLayer] = Llama4DecoderLayer): layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config, return Llama4Model(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
layer_type=layer_type) layer_type=layer_type)
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple
import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
disable_input_layernorm: bool,
prefix: str = "",
) -> None:
super().__init__(config, prefix=prefix)
# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
if disable_input_layernorm:
del self.input_layernorm
self.input_layernorm = nn.Identity()
class LlamaModel(nn.Module):
def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
self.config.hidden_size,
bias=False)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
input_embeds = self.embed_tokens(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
return hidden_states + residual
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
scale=logit_scale)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
model_weights = {}
for name, loaded_weight in weights:
if "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
...@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -32,8 +32,9 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
...@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel ...@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features, from .vision import get_vision_encoder_info
select_patch_features)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class LlavaImageEmbeddingInputs(TypedDict): class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -145,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy( def _apply_feature_select_strategy(
self, self,
strategy: str, strategy: str,
...@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) ...@@ -201,30 +186,31 @@ _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]): class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class LlavaProcessingInfo(BaseLlavaProcessingInfo): class LlavaProcessingInfo(BaseLlavaProcessingInfo):
...@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor( ...@@ -343,23 +329,6 @@ class PixtralHFMultiModalProcessor(
for p, (h, w) in zip(pixel_values, image_sizes) for p, (h, w) in zip(pixel_values, image_sizes)
] ]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor( ...@@ -369,7 +338,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor( ...@@ -404,7 +372,7 @@ class PixtralHFMultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id tokens[-1] = image_end_id
return tokens return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -612,17 +580,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
) )
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
...@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -708,22 +668,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes) image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -735,7 +689,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
...@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings) for i, patch_features_batch in enumerate(patch_embeddings)
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -16,13 +16,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -61,22 +62,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1} return {"video": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info() vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size() width = height = vision_encoder_info.get_image_size()
...@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -146,22 +131,27 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
class LlavaNextVideoDummyInputsBuilder( class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]): BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
processor = self.info.get_hf_processor() processor = self.info.get_hf_processor()
video_token = processor.video_token video_token = processor.video_token
return video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_videos = mm_counts.get("video", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { return {
"video": "video":
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
...@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder( ...@@ -171,11 +161,6 @@ class LlavaNextVideoDummyInputsBuilder(
) )
} }
return ProcessorInputs(
prompt_text=video_token * num_videos,
mm_data=mm_data,
)
class LlavaNextVideoMultiModalProcessor( class LlavaNextVideoMultiModalProcessor(
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]): BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
...@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -421,6 +406,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return [e.flatten(0, 1) for e in embeds] return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
......
...@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn ...@@ -19,11 +19,11 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
...@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor # with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features( def _get_num_unpadded_features(
...@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -236,11 +226,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
class LlavaOnevisionDummyInputsBuilder( class LlavaOnevisionDummyInputsBuilder(
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]): LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder( ...@@ -248,13 +234,23 @@ class LlavaOnevisionDummyInputsBuilder(
image_token = processor.image_token image_token = processor.image_token
video_token = processor.video_token video_token = processor.video_token
return image_token * num_images + video_token * num_videos
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, self.info.get_num_frames_with_most_features(seq_len,
mm_counts) mm_counts)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
...@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder( ...@@ -268,11 +264,6 @@ class LlavaOnevisionDummyInputsBuilder(
) )
} }
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
class LlavaOnevisionMultiModalProcessor( class LlavaOnevisionMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]): BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
...@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -852,6 +843,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim) image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
...@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group ...@@ -13,6 +13,8 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import ( from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -57,7 +59,6 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim, head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon, rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act, activation=config.hidden_act,
chunk_size=config.chunk_size,
quant_config=quant_config) quant_config=quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
...@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -67,7 +68,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor], mamba2_metadata: Mamba2Metadata,
**kwargs, **kwargs,
): ):
if residual is None: if residual is None:
...@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module): ...@@ -77,7 +78,7 @@ class Mamba2DecoderLayer(nn.Module):
hidden_states, residual = self.norm(hidden_states, residual) hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params, hidden_states = self.mixer(hidden_states, mamba_cache_params,
sequence_idx) mamba2_metadata)
return hidden_states, residual return hidden_states, residual
...@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module): ...@@ -138,20 +139,13 @@ class Mamba2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
# pass a sequence index tensor, that is required for
# proper continuous batching computation including
# chunked prefill
seq_idx = None
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
if attn_metadata.num_prefills > 0:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32) mamba2_metadata = prepare_mamba2_metadata(
for i, (srt, end) in enumerate( chunk_size=self.config.chunk_size,
zip( input_ids=input_ids,
attn_metadata.query_start_loc, attn_metadata=attn_metadata,
attn_metadata.query_start_loc[1:], )
)):
seq_idx[srt:end] = i
seq_idx.unsqueeze_(0)
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
...@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module): ...@@ -162,7 +156,7 @@ class Mamba2Model(nn.Module):
residual=residual, residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx( mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer), i - self.start_layer),
sequence_idx=seq_idx) mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
......
...@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import ( ...@@ -35,13 +35,14 @@ from transformers.models.whisper.modeling_whisper import (
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
DictEmbeddingItems, ModalityData, DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems, ModalityDataItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
from vllm.multimodal.profiling import ProcessorInputs PromptUpdateDetails)
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
MiniCPMVDummyInputsBuilder, MiniCPMVDummyInputsBuilder,
...@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, ...@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix) maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
...@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict): ...@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]` which equals to `audio_features.shape[-1]`
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
...@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict): ...@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list. Length of each slice may vary, so pass it as a list.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs] MiniCPMOAudioEmbeddingInputs]
...@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features=MultiModalFieldConfig.batched("audio"), audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
) )
...@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): ...@@ -143,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
def _parse_audio_data( def _parse_audio_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]], data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMOAudioEmbeddingItems( return MiniCPMOAudioEmbeddingItems(
data, data,
...@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -159,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {**super().get_supported_mm_limits(), "audio": None} return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio":
self.get_max_audio_tokens(),
}
def get_audio_placeholder( def get_audio_placeholder(
self, self,
audio_lens: int, audio_lens: int,
...@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -197,8 +169,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step = self.get_default_audio_pool_step() pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100 fbank_feat_in_chunk = 100
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 return (cnn_feat_in_chunk - pool_step) // pool_step + 1
return num_audio_tokens + 2 # <audio>(<unk>*N)</audio>
def get_max_audio_chunks_with_most_features(self) -> int: def get_max_audio_chunks_with_most_features(self) -> int:
return 30 return 30
...@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -209,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate() sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio> num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features( def get_num_frames_with_most_features(
...@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -236,29 +206,31 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
class MiniCPMODummyInputsBuilder( class MiniCPMODummyInputsBuilder(
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]): MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
self, seq_len: int, mm_counts: Mapping[str, num_audios = mm_counts.get("audio", 0)
int]) -> ProcessorInputs:
audio_prompt_texts = self.info.audio_pattern * num_audios
return super().get_dummy_text(mm_counts) + audio_prompt_texts
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_len = self.info.get_max_audio_chunks_with_most_features() * \ audio_len = self.info.get_max_audio_chunks_with_most_features() * \
self.info.get_default_audio_sampling_rate() self.info.get_default_audio_sampling_rate()
processor_inputs = super().get_dummy_processor_inputs(
seq_len, mm_counts)
audio_prompt_texts = self.info.audio_pattern * num_audios
audio_mm_data = { audio_mm_data = {
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len, num_audios=num_audios)
} }
return ProcessorInputs( return {
prompt_text=processor_inputs.prompt_text + audio_prompt_texts, **super().get_dummy_mm_data(seq_len, mm_counts),
mm_data={ **audio_mm_data,
**processor_inputs.mm_data, }
**audio_mm_data,
},
)
class MiniCPMOMultiModalProcessor( class MiniCPMOMultiModalProcessor(
...@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor( ...@@ -295,13 +267,6 @@ class MiniCPMOMultiModalProcessor(
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {} audio_inputs = {}
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else: else:
audio_inputs = self._base_call_hf_processor( audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios), prompts=[self.info.audio_pattern] * len(parsed_audios),
...@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor( ...@@ -323,27 +288,7 @@ class MiniCPMOMultiModalProcessor(
] ]
audio_inputs["audio_features"] = unpadded_audio_features audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
...@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor( ...@@ -384,7 +329,10 @@ class MiniCPMOMultiModalProcessor(
else: else:
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
return self.get_audio_prompt_texts(audio_len) return PromptUpdateDetails.select_text(
self.get_audio_prompt_texts(audio_len),
"<unk>",
)
return [ return [
*base_updates, *base_updates,
...@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -713,13 +661,6 @@ class MiniCPMO(MiniCPMV2_6):
assert isinstance(audio_token_id, torch.Tensor) assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item()) self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)): if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. " raise ValueError("Incorrect type of audio_embeds. "
...@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -730,7 +671,6 @@ class MiniCPMO(MiniCPMV2_6):
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds", type="audio_embeds",
audio_embeds=audio_embeds_flat, audio_embeds=audio_embeds_flat,
embed_is_patch=audio_embed_is_patch,
) )
if not isinstance(audio_features, (torch.Tensor, list)): if not isinstance(audio_features, (torch.Tensor, list)):
...@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -749,7 +689,6 @@ class MiniCPMO(MiniCPMV2_6):
type="audio_features", type="audio_features",
audio_features=audio_features_flat, audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat, audio_feature_lens=audio_feature_lens_flat,
embed_is_patch=audio_embed_is_patch,
) )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
...@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -781,10 +720,6 @@ class MiniCPMO(MiniCPMV2_6):
if modality == "audios": if modality == "audios":
audio_input = modalities["audios"] audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input) audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(audio_features)
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
return multimodal_embeddings return multimodal_embeddings
...@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys ...@@ -48,7 +48,8 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
ImageProcessorItems, ImageSize, ImageProcessorItems, ImageSize,
ModalityData, ModalityDataItems, ModalityData, ModalityDataItems,
...@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ...@@ -56,8 +57,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems) VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists from vllm.utils import flatten_2d_lists
...@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, ...@@ -67,7 +68,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP) SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# For profile run # For profile run
_MAX_FRAMES_PER_VIDEO = 16 _MAX_FRAMES_PER_VIDEO = 16
...@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict): ...@@ -90,14 +90,6 @@ class MiniCPMVImagePixelInputs(TypedDict):
This should be in `(height, width)` format. This should be in `(height, width)` format.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
num_slices: torch.Tensor num_slices: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
...@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict): ...@@ -112,14 +104,6 @@ class MiniCPMVImageEmbeddingInputs(TypedDict):
instead of a batched tensor. instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
MiniCPMVImageEmbeddingInputs] MiniCPMVImageEmbeddingInputs]
...@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -245,12 +229,10 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_sizes=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.batched("image"), tgt_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"), video_pixel_values=MultiModalFieldConfig.batched("video"),
video_image_sizes=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"),
video_embed_is_patch=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos), video_token_id=MultiModalFieldConfig.shared("video", num_videos),
) )
...@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -308,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_image_data( def _parse_image_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVImageEmbeddingItems( return MiniCPMVImageEmbeddingItems(
data, data,
...@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser): ...@@ -320,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
def _parse_video_data( def _parse_video_data(
self, self,
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]: ) -> Optional[ModalityDataItems[Any, Any]]:
if isinstance(data, dict): if isinstance(data, dict):
return MiniCPMVVideoEmbeddingItems( return MiniCPMVVideoEmbeddingItems(
data, data,
...@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -365,18 +347,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return mm_limits return mm_limits
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(
seq_len, mm_counts)
return mm_max_tokens
def get_slice_image_placeholder( def get_slice_image_placeholder(
self, self,
image_size: ImageSize, image_size: ImageSize,
...@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -398,22 +368,43 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
use_image_id=use_image_id, use_image_id=use_image_id,
) )
def get_sliced_grid(
self,
image_size: ImageSize,
# For MiniCPM V/O 2.6
max_slice_nums: Optional[int] = None,
) -> Optional[tuple[int, int]]:
image_processor = self.get_image_processor()
version = self.get_model_version()
if version == (2, 0) or version == (2, 5):
return image_processor.get_sliced_grid(image_size)
if max_slice_nums is None:
max_slice_nums = image_processor.max_slice_nums
return image_processor.get_sliced_grid(
image_size,
max_slice_nums=max_slice_nums,
)
def get_num_image_tokens( def get_num_image_tokens(
self, self,
image_size: ImageSize, image_size: ImageSize,
max_slice_nums: Optional[int] = None, max_slice_nums: Optional[int] = None,
use_image_id: bool = True,
) -> int: ) -> int:
tokenizer = self.get_tokenizer() image_processor = self.get_image_processor()
image_placeholders = self.get_slice_image_placeholder(
grid = self.get_sliced_grid(
image_size, image_size,
max_slice_nums=max_slice_nums, max_slice_nums=max_slice_nums,
use_image_id=use_image_id,
) )
image_token_ids = tokenizer.encode(image_placeholders, if grid is None:
add_special_tokens=False) ncols = nrows = 0
else:
ncols, nrows = grid
return len(image_token_ids) return (ncols * nrows + 1) * image_processor.image_feature_size
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_size = self.get_image_size_with_most_features() image_size = self.get_image_size_with_most_features()
...@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -433,7 +424,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return self.get_num_image_tokens( return self.get_num_image_tokens(
frame_size, frame_size,
max_slice_nums=self.get_video_max_slice_num(), max_slice_nums=self.get_video_max_slice_num(),
use_image_id=False,
) )
def get_max_video_tokens( def get_max_video_tokens(
...@@ -482,11 +472,20 @@ _I = TypeVar("_I", ...@@ -482,11 +472,20 @@ _I = TypeVar("_I",
class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_prompt_texts = self.info.image_pattern * num_images
video_prompt_texts = self.info.video_pattern * num_videos
return image_prompt_texts + video_prompt_texts
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -497,7 +496,7 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
num_video_frames = \ num_video_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
mm_data = { return {
"image": "image":
self._get_dummy_images(width=image_width, self._get_dummy_images(width=image_width,
height=image_height, height=image_height,
...@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -509,13 +508,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
] * num_videos, ] * num_videos,
} }
image_prompt_texts = self.info.image_pattern * num_images
video_prompt_texts = self.info.video_pattern * num_videos
return ProcessorInputs(prompt_text=image_prompt_texts +
video_prompt_texts,
mm_data=mm_data)
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
...@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -539,14 +531,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
use_image_id=False, use_image_id=False,
) * num_frames ) * num_frames
def get_embed_is_patch(
self,
input_ids: list[int],
) -> torch.Tensor:
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
return torch.tensor(input_ids) == unk_token_id
def process_images( def process_images(
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
...@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -570,26 +554,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
image_repl_features = [
self.get_image_prompt_texts(size, idx)
for idx, size in enumerate(image_sizes)
]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(image_repl_tokens)
for image_repl_tokens in image_repls_feature_tokens
]
image_inputs["embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id) image_inputs["image_token_id"] = torch.tensor(unk_token_id)
...@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -625,31 +590,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
frame_sizes = [
parsed_videos.get_frame_size(i) for i in range(len(parsed_videos))
]
num_frames = [
parsed_videos.get_num_frames(i) for i in range(len(parsed_videos))
]
video_repl_features = [
self.get_video_prompt_texts(size, nframes)
for size, nframes in zip(frame_sizes, num_frames)
]
tokenizer = self.info.get_tokenizer()
video_repls_feature_tokens = [
tokenizer.encode(video_repl, add_special_tokens=False)
for video_repl in video_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(video_repl_tokens)
for video_repl_tokens in video_repls_feature_tokens
]
video_inputs["embed_is_patch"] = embed_is_patch
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id) video_inputs["video_token_id"] = torch.tensor(unk_token_id)
...@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -740,7 +683,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
return self.get_image_prompt_texts(image_size, item_idx) return PromptUpdateDetails.select_text(
self.get_image_prompt_texts(image_size, item_idx),
"<unk>",
)
def get_video_replacement(item_idx: int): def get_video_replacement(item_idx: int):
videos = mm_items.get_items( videos = mm_items.get_items(
...@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -749,7 +695,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
frame_size = videos.get_frame_size(item_idx) frame_size = videos.get_frame_size(item_idx)
num_frames = videos.get_num_frames(item_idx) num_frames = videos.get_num_frames(item_idx)
return self.get_video_prompt_texts(frame_size, num_frames) return PromptUpdateDetails.select_text(
self.get_video_prompt_texts(frame_size, num_frames),
"<unk>",
)
get_replacement = { get_replacement = {
"image": get_image_replacement, "image": get_image_replacement,
...@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -832,14 +781,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
assert isinstance(image_token_id, torch.Tensor) assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item()) self.mm_token_ids.add(image_token_id.flatten().unique().item())
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of embed_is_patch for {modality=}. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError( raise ValueError(
...@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -851,7 +792,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return MiniCPMVImageEmbeddingInputs( return MiniCPMVImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
image_embeds=image_embeds_flat, image_embeds=image_embeds_flat,
embed_is_patch=embed_is_patch,
) )
if not isinstance(pixel_values, (torch.Tensor, list)): if not isinstance(pixel_values, (torch.Tensor, list)):
...@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -879,7 +819,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
type="pixel_values", type="pixel_values",
pixel_values=pixel_values_flat, pixel_values=pixel_values_flat,
tgt_sizes=tgt_sizes_flat, tgt_sizes=tgt_sizes_flat,
embed_is_patch=embed_is_patch,
num_slices=num_slices_flat, num_slices=num_slices_flat,
) )
...@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -936,22 +875,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
image_features = self._process_vision_input(image_input) image_features = self._process_vision_input(image_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(image_features)
scatter_patch_features(
image_features,
image_input["embed_is_patch"],
))
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_features = self._process_vision_input(video_input) video_features = self._process_vision_input(video_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(video_features)
scatter_patch_features(
video_features,
video_input["embed_is_patch"],
))
return multimodal_embeddings return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
...@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -971,7 +905,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
list(self.mm_token_ids), list(self.mm_token_ids),
) )
return inputs_embeds return inputs_embeds
......
...@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -22,21 +22,22 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features, from .vision import get_vision_encoder_info
select_patch_features)
class Mistral3ImagePixelInputs(TypedDict): class Mistral3ImagePixelInputs(TypedDict):
...@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict): ...@@ -49,14 +50,6 @@ class Mistral3ImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
class Mistral3PatchMerger(nn.Module): class Mistral3PatchMerger(nn.Module):
""" """
...@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -170,13 +163,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
...@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): ...@@ -194,44 +180,37 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width = height = vision_encoder_info.get_image_size() width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height) return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo) _I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]): class Mistral3DummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
return image_token * num_images
def get_dummy_mm_data(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
mm_data = { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images)
} }
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
class Mistral3ProcessingInfo(BaseLlavaProcessingInfo): class Mistral3ProcessingInfo(BaseLlavaProcessingInfo):
...@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor( ...@@ -266,23 +245,6 @@ class Mistral3MultiModalProcessor(
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes) p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
] ]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor( ...@@ -292,7 +254,6 @@ class Mistral3MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor( ...@@ -327,7 +288,7 @@ class Mistral3MultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id tokens[-1] = image_end_id
return tokens return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -418,8 +379,6 @@ def init_vision_tower_for_llava( ...@@ -418,8 +379,6 @@ def init_vision_tower_for_llava(
) )
# TODO(mgoin): Support V1, there are issues with image batching/chunking
# that need to be resolved first.
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
_build_mistral3_processor, _build_mistral3_processor,
info=_build_mistral3_info, info=_build_mistral3_info,
...@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -509,16 +468,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of pixel values. " raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
assert self.config.vision_config.model_type == "pixtral"
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
return Mistral3ImagePixelInputs( return Mistral3ImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=flatten_bn(embed_is_patch),
) )
def _process_image_input( def _process_image_input(
...@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -549,6 +501,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = (image_embeds, ) image_embeds = (image_embeds, )
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
...@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -557,10 +512,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return scatter_patch_features( return vision_embeddings
vision_embeddings,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -572,7 +524,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
...@@ -53,7 +53,7 @@ from vllm import _custom_ops as ops ...@@ -53,7 +53,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -76,7 +76,7 @@ class MixtralMoE(nn.Module): ...@@ -76,7 +76,7 @@ class MixtralMoE(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
dp_size: Optional[int] = None, dp_size: Optional[int] = None,
prefix: str = "",): prefix: str = ""):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -99,7 +99,7 @@ class MixtralMoE(nn.Module): ...@@ -99,7 +99,7 @@ class MixtralMoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size, tp_size=tp_size,
dp_size=dp_size, dp_size=dp_size,
prefix=f"{prefix}.experts",) prefix=f"{prefix}.experts")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
...@@ -222,7 +222,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -222,7 +222,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe",) prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, self.post_attention_layernorm = RMSNorm(config.hidden_size,
...@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class MixtralModel(nn.Module): class MixtralModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "",): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -264,6 +264,8 @@ class MixtralModel(nn.Module): ...@@ -264,6 +264,8 @@ class MixtralModel(nn.Module):
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab self.vocab_size = config.vocab_size + lora_vocab
...@@ -278,7 +280,7 @@ class MixtralModel(nn.Module): ...@@ -278,7 +280,7 @@ class MixtralModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: MixtralDecoderLayer( lambda prefix: MixtralDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix, config, cache_config, quant_config=quant_config, prefix=prefix
), ),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
...@@ -286,6 +288,12 @@ class MixtralModel(nn.Module): ...@@ -286,6 +288,12 @@ class MixtralModel(nn.Module):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -317,97 +325,6 @@ class MixtralModel(nn.Module): ...@@ -317,97 +325,6 @@ class MixtralModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.parallel_config = vllm_config.parallel_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
...@@ -428,9 +345,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -428,9 +345,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales # Loading kv cache quantization scales
...@@ -475,7 +389,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -475,7 +389,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if ((name.endswith(".bias") or name.endswith("_bias")) if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict): and name not in params_dict):
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -527,3 +440,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -527,3 +440,90 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
return loader.load_weights(weights)
...@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope ...@@ -45,7 +45,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -420,6 +421,11 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
if name.endswith("scale"):
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment