Commit dcb5624a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-dev

parents 55880ca2 ba41cc90
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
LlamaForCausalLM)
from vllm.v1.sample.metadata import SamplingMetadata
from .utils import AutoWeightsLoader, maybe_prefix
logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, quant_config=quant_config, prefix=prefix)
# override qkv
self.self_attn.qkv_proj = QKVParallelLinear(
2 * self.hidden_size,
self.self_attn.head_dim,
self.self_attn.total_num_heads,
self.self_attn.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "qkv_proj"),
)
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
embeds: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
embeds = self.input_layernorm(embeds)
hidden_states = self.hidden_norm(hidden_states)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
# Fully Connected
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class LlamaModel(nn.Module):
def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
)
])
if hasattr(self.config, "target_hidden_size"):
self.fc = torch.nn.Linear(self.config.target_hidden_size * 3,
self.config.hidden_size,
bias=False)
else:
self.fc = torch.nn.Linear(self.config.hidden_size * 3,
self.config.hidden_size,
bias=False)
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if (hidden_states.shape[-1] != input_embeds.shape[-1]):
hidden_states = self.fc(hidden_states)
residual = None
hidden_states, residual = self.layers[0](
positions,
input_embeds,
hidden_states,
residual,
)
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
return hidden_states, hidden_prenorm
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if 'midlayer.' in name:
name = name.replace('midlayer.', 'layers.0.')
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.lm_head = ParallelLMHead(
self.config.draft_vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.draft_vocab_size,
padding_size=(DEFAULT_VOCAB_PADDING_SIZE),
prefix="")
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
scale=logit_scale)
self.draft_id_to_target_id = nn.Parameter(
torch.zeros((self.config.draft_vocab_size),
dtype=torch.long).type(torch.LongTensor),
requires_grad=False,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
targets = base + self.draft_id_to_target_id
logits_new = logits.new_full((
logits.shape[0],
self.config.vocab_size,
), float('-inf'))
logits_new[:, targets] = logits
return logits_new
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
)
model_weights = {}
for name, loaded_weight in weights:
if "t2d" in name:
continue
if "d2t" in name:
name = name.replace("d2t", "draft_id_to_target_id")
elif "lm_head" not in name:
name = "model." + name
model_weights[name] = loaded_weight
return loader.load_weights(model_weights.items())
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union, cast) TypeVar, Union, cast)
...@@ -23,7 +22,6 @@ from vllm.model_executor.layers.activation import get_act_fn ...@@ -23,7 +22,6 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -546,13 +544,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -546,13 +544,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -763,13 +754,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -763,13 +754,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional, from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, TypeVar, Union) Protocol, Set, Tuple, TypedDict, TypeVar, Union)
...@@ -13,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import ( ...@@ -13,7 +12,6 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
...@@ -250,13 +248,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -250,13 +248,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -585,13 +576,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -585,13 +576,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
...@@ -12,7 +11,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig, ...@@ -12,7 +11,6 @@ from transformers import (BatchFeature, LlavaNextVideoConfig,
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -301,13 +299,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -301,13 +299,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_video_pixel_values( def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, List[torch.Tensor]]:
...@@ -469,13 +460,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -469,13 +460,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -16,7 +15,6 @@ from typing_extensions import NotRequired ...@@ -16,7 +15,6 @@ from typing_extensions import NotRequired
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -455,13 +453,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -455,13 +453,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -583,21 +574,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -583,21 +574,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {} mm_input_by_modality = {}
# Preserve the order of modalities if there are multiple of them # Preserve the order of modalities if there are multiple of them
# from the order of kwargs. # from the order of kwargs.
for input_key in kwargs: for input_key in kwargs:
if input_key in ("pixel_values", if input_key in ("pixel_values", "image_embeds"
"image_embeds") and "images" not in modalities: ) and "image" not in mm_input_by_modality:
modalities["images"] = self._parse_and_validate_image_input( mm_input_by_modality[
**kwargs) "image"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_videos", if input_key in ("pixel_values_videos", "video_embeds"
"video_embeds") and "videos" not in modalities: ) and "video" not in mm_input_by_modality:
modalities["videos"] = self._parse_and_validate_video_input( mm_input_by_modality[
**kwargs) "video"] = self._parse_and_validate_video_input(**kwargs)
return modalities return mm_input_by_modality
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor: strategy: str) -> torch.Tensor:
...@@ -848,8 +839,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -848,8 +839,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
if not modalities: **kwargs)
if not mm_input_by_modality:
return None return None
# The result multimodal_embeddings is tuple of tensors, with each # The result multimodal_embeddings is tuple of tensors, with each
...@@ -858,14 +850,13 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -858,14 +850,13 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: It is important to iterate over the keys in this dictionary # NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities. # to preserve the order of the modalities.
for modality in modalities: for modality in mm_input_by_modality:
if modality == "images": multimodal_input = mm_input_by_modality[modality]
image_input = modalities["images"] if modality == "image":
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += tuple(vision_embeddings) multimodal_embeddings += tuple(vision_embeddings)
if modality == "videos": if modality == "video":
video_input = modalities["videos"] video_embeddings = self._process_video_pixels(multimodal_input)
video_embeddings = self._process_video_pixels(video_input)
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
...@@ -957,13 +948,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -957,13 +948,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -14,7 +14,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -27,7 +26,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -27,7 +26,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType from vllm.utils import LayerBlockType
from .utils import (is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
...@@ -154,6 +153,26 @@ class MambaModel(nn.Module): ...@@ -154,6 +153,26 @@ class MambaModel(nn.Module):
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
SupportsV0Only): SupportsV0Only):
...@@ -193,7 +212,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, ...@@ -193,7 +212,6 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors) self.backbone.make_empty_intermediate_tensors)
...@@ -247,30 +265,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, ...@@ -247,30 +265,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) loader = AutoWeightsLoader(self)
loaded_params: Set[str] = set() return loader.load_weights(weights)
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import ( ...@@ -19,7 +19,6 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards) MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -208,7 +207,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, ...@@ -208,7 +207,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.backbone.make_empty_intermediate_tensors) self.backbone.make_empty_intermediate_tensors)
...@@ -282,14 +280,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, ...@@ -282,14 +280,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
......
...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -553,7 +552,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -553,7 +552,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(unpadded_vocab_size, self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -584,14 +582,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -584,14 +582,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
import math import math
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union) Union)
...@@ -40,7 +40,6 @@ from vllm.config import VllmConfig ...@@ -40,7 +40,6 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed) get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
...@@ -758,13 +757,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -758,13 +757,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors) self.llm.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.llm, "sampler"):
return self.llm.sampler
return get_sampler()
def _parse_and_validate_vision_input( def _parse_and_validate_vision_input(
self, self,
modality: str, modality: str,
...@@ -946,14 +938,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -946,14 +938,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.llm.compute_logits(hidden_states, sampling_metadata) return self.llm.compute_logits(hidden_states, sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -33,7 +33,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...@@ -994,7 +993,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -994,7 +993,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.config.vocab_size) self.config.vocab_size)
self.sampler = Sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
...@@ -1030,16 +1028,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1030,16 +1028,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors: device: torch.device) -> IntermediateTensors:
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union) TypeVar, Union)
...@@ -19,7 +18,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -19,7 +18,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -274,6 +272,9 @@ class Mistral3MultiModalProcessor( ...@@ -274,6 +272,9 @@ class Mistral3MultiModalProcessor(
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig) assert isinstance(vision_config, PixtralVisionConfig)
# Need to sneak in spatial_merge_size for Mistral3
vision_config.spatial_merge_size = getattr(hf_config,
"spatial_merge_size", 1)
encoder_info = PixtralHFEncoderInfo(vision_config) encoder_info = PixtralHFEncoderInfo(vision_config)
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
...@@ -435,13 +436,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -435,13 +436,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -598,13 +592,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -598,13 +592,6 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
......
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -489,7 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -489,7 +488,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -516,14 +514,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -516,14 +514,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
......
...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -42,7 +42,6 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -372,7 +371,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -372,7 +371,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
...@@ -399,14 +397,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -399,14 +397,6 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -47,7 +47,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -1211,7 +1210,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1211,7 +1210,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
self.logits_processor = LogitsProcessor(config.output_hidden_states, self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size) config.text_config.vocab_size)
self.sampler = get_sampler()
def compute_logits( def compute_logits(
self, self,
...@@ -1222,14 +1220,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1222,14 +1220,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states, sampling_metadata) hidden_states, sampling_metadata)
return logits return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def unpack_data(self, def unpack_data(self,
image_data: Union[List[torch.Tensor], torch.Tensor], image_data: Union[List[torch.Tensor], torch.Tensor],
padding_value=0) -> torch.Tensor: padding_value=0) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment