Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
# but implemented by the Phi-Speech team # but implemented by the Phi-Speech team
#!/usr/bin/env python3 #!/usr/bin/env python3
import math import math
from typing import Optional, Tuple, Union from typing import Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -1586,7 +1586,7 @@ class AttModule(nn.Module): ...@@ -1586,7 +1586,7 @@ class AttModule(nn.Module):
memory: Optional[Tensor] = None, memory: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None, pos_emb: Optional[Tensor] = None,
att_mask: Optional[Tensor] = None, att_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
"""AttModule forward """AttModule forward
Args: Args:
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only PhiMoE model.""" """Inference-only PhiMoE model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -505,8 +506,8 @@ class PhiMoEModel(nn.Module): ...@@ -505,8 +506,8 @@ class PhiMoEModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -521,7 +522,7 @@ class PhiMoEModel(nn.Module): ...@@ -521,7 +522,7 @@ class PhiMoEModel(nn.Module):
num_experts=self.config.num_local_experts) num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
...@@ -657,8 +658,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -657,8 +658,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["rotary_emb.inv_freq"]), skip_prefixes=(["rotary_emb.inv_freq"]),
......
...@@ -4,7 +4,7 @@ import math ...@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import cached_property from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -438,18 +438,18 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -438,18 +438,18 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder") return weight[0].startswith("vision_encoder")
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter") return weight[0].startswith("vision_language_adapter")
def is_patch_merger(weight: Tuple[str, torch.Tensor]): def is_patch_merger(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("patch_merger") return weight[0].startswith("patch_merger")
def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]): def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("pre_mm_projector_norm") return weight[0].startswith("pre_mm_projector_norm")
# Get references to parameters for direct loading # Get references to parameters for direct loading
...@@ -566,7 +566,7 @@ def apply_rotary_emb_vit( ...@@ -566,7 +566,7 @@ def apply_rotary_emb_vit(
xq: torch.Tensor, xq: torch.Tensor,
xk: torch.Tensor, xk: torch.Tensor,
freqs_cis: torch.Tensor, freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
assert freqs_cis.dtype == torch.complex64 assert freqs_cis.dtype == torch.complex64
...@@ -671,7 +671,7 @@ class Transformer(nn.Module): ...@@ -671,7 +671,7 @@ class Transformer(nn.Module):
return x return x
def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor: def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
positions = torch.cat([ positions = torch.cat([
torch.stack( torch.stack(
torch.meshgrid( torch.meshgrid(
...@@ -733,7 +733,7 @@ class VisionTransformer(nn.Module): ...@@ -733,7 +733,7 @@ class VisionTransformer(nn.Module):
def forward( def forward(
self, self,
images: List[torch.Tensor], images: list[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -1023,7 +1023,7 @@ class PixtralHFAttention(nn.Module): ...@@ -1023,7 +1023,7 @@ class PixtralHFAttention(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
position_embeddings: torch.Tensor, position_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size() batch, patches, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states) qkv_states, _ = self.qkv_proj(hidden_states)
...@@ -1249,8 +1249,8 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1249,8 +1249,8 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded # (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -1260,7 +1260,7 @@ class PixtralHFVisionModel(nn.Module): ...@@ -1260,7 +1260,7 @@ class PixtralHFVisionModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
layer_count = len(self.transformer.layers) layer_count = len(self.transformer.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only PLaMo2 model.""" """Inference-only PLaMo2 model."""
import math import math
from typing import Iterable, Optional, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -659,7 +660,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, ...@@ -659,7 +660,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
hidden_size = (self.config.mamba_num_heads * hidden_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head) self.config.hidden_size_per_head)
...@@ -682,7 +683,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid, ...@@ -682,7 +683,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model.""" """Inference-only IBM/NASA Prithvi Geospatial model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -154,7 +154,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, ...@@ -154,7 +154,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"by PrithviGeospatialMAE.") "by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data( def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor): if not isinstance(pixel_values, torch.Tensor):
...@@ -195,8 +195,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, ...@@ -195,8 +195,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_list = [] params_list = []
model_buffers = dict(self.named_buffers()) model_buffers = dict(self.named_buffers())
loaded_buffers = [] loaded_buffers = []
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.""" """Inference-only QWen model compatible with HuggingFace weights."""
import json import json
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -76,7 +77,7 @@ class QWenAttention(nn.Module): ...@@ -76,7 +77,7 @@ class QWenAttention(nn.Module):
num_heads: int, num_heads: int,
max_position_embeddings: int, max_position_embeddings: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
...@@ -166,7 +167,7 @@ class QWenBlock(nn.Module): ...@@ -166,7 +167,7 @@ class QWenBlock(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -284,15 +285,15 @@ class QWenBaseModel(nn.Module): ...@@ -284,15 +285,15 @@ class QWenBaseModel(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0), ("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1), ("gate_up_proj", "w1", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Any, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -108,7 +109,7 @@ class Qwen2Attention(nn.Module): ...@@ -108,7 +109,7 @@ class Qwen2Attention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None, rope_scaling: Optional[tuple] = None,
prefix: str = "", prefix: str = "",
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None, dual_chunk_attention_config: Optional[dict[str, Any]] = None,
...@@ -245,7 +246,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -245,7 +246,7 @@ class Qwen2DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -367,8 +368,8 @@ class Qwen2Model(nn.Module): ...@@ -367,8 +368,8 @@ class Qwen2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -378,7 +379,7 @@ class Qwen2Model(nn.Module): ...@@ -378,7 +379,7 @@ class Qwen2Model(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -490,8 +491,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -490,8 +491,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
...@@ -559,7 +560,7 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -559,7 +560,7 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights weights = ((name, data) for name, data in weights
if not name.startswith("lm_head.")) if not name.startswith("lm_head."))
......
...@@ -21,10 +21,10 @@ ...@@ -21,10 +21,10 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2.5-Omni model (thinker part).""" """Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Iterable, Mapping, Sequence
from copy import copy from copy import copy
from functools import partial from functools import partial
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Sequence, from typing import Any, Optional, Union
Set, Tuple, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -138,7 +138,7 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, ...@@ -138,7 +138,7 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, List[float]]] = None, fps: Optional[Union[float, list[float]]] = None,
**kwargs: object, **kwargs: object,
) -> Qwen2_5OmniProcessor: ) -> Qwen2_5OmniProcessor:
if fps is not None: if fps is not None:
...@@ -550,7 +550,7 @@ class Qwen2_5OmniConditionalGenerationMixin: ...@@ -550,7 +550,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, self,
**kwargs: Dict[str, Any], **kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLImageInputs]: ) -> Optional[Qwen2_5_VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
...@@ -589,7 +589,7 @@ class Qwen2_5OmniConditionalGenerationMixin: ...@@ -589,7 +589,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_video_input( def _parse_and_validate_video_input(
self, self,
**kwargs: Dict[str, Any], **kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLVideoInputs]: ) -> Optional[Qwen2_5_VLVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None) pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None) video_embeds = kwargs.pop("video_embeds", None)
...@@ -627,7 +627,7 @@ class Qwen2_5OmniConditionalGenerationMixin: ...@@ -627,7 +627,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _process_audio_input( def _process_audio_input(
self, self,
audio_input: Qwen2AudioInputs, audio_input: Qwen2AudioInputs,
audio_hashes: List[str] = None, audio_hashes: list[str] = None,
cached_audio_features: torch.Tensor = None, cached_audio_features: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -676,7 +676,7 @@ class Qwen2_5OmniConditionalGenerationMixin: ...@@ -676,7 +676,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _process_video_input( def _process_video_input(
self, self,
video_input: Qwen2_5_VLVideoInputs, video_input: Qwen2_5_VLVideoInputs,
video_hashes: List[str] = None, video_hashes: list[str] = None,
cached_video_embeds: torch.Tensor = None) -> torch.Tensor: cached_video_embeds: torch.Tensor = None) -> torch.Tensor:
if video_input["type"] == "video_embeds": if video_input["type"] == "video_embeds":
return video_input["video_embeds"].type(self.visual.dtype) return video_input["video_embeds"].type(self.visual.dtype)
...@@ -825,7 +825,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -825,7 +825,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
if audio_input is None and image_input is None and video_input is None: if audio_input is None and image_input is None and video_input is None:
return None return None
multimodal_embeddings: List[Tuple[NestedTensors, str]] = [] multimodal_embeddings: list[tuple[NestedTensors, str]] = []
if audio_input is not None: if audio_input is not None:
audio_embeds = self._process_audio_input(audio_input) audio_embeds = self._process_audio_input(audio_input)
...@@ -891,8 +891,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -891,8 +891,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=["talker.", "token2wav."], skip_prefixes=["talker.", "token2wav."],
......
...@@ -24,9 +24,9 @@ ...@@ -24,9 +24,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping
from functools import partial from functools import partial
from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set, from typing import Callable, Literal, Optional, TypedDict, Union
Tuple, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -91,7 +91,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TypedDict): ...@@ -91,7 +91,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
image_embeds: torch.Tensor image_embeds: torch.Tensor
"""Supported types: """Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features. - list[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features. Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features - `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors). (concatenation of all images' feature tensors).
...@@ -137,7 +137,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): ...@@ -137,7 +137,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"] type: Literal["video_embeds"]
video_embeds: torch.Tensor video_embeds: torch.Tensor
"""Supported types: """Supported types:
- List[`torch.Tensor`]: A list of tensors holding all videos' features. - list[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features. Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features - `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors). (concatenation of all videos' feature tensors).
...@@ -709,8 +709,8 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -709,8 +709,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_states = hidden_states[reverse_indices, :] hidden_states = hidden_states[reverse_indices, :]
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"), ("attn.qkv.", "attn.q.", "q"),
...@@ -718,7 +718,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -718,7 +718,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
("attn.qkv.", "attn.v.", "v"), ("attn.qkv.", "attn.v.", "v"),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
...@@ -750,7 +750,7 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -750,7 +750,7 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None, size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, List[float]]] = None, fps: Optional[Union[float, list[float]]] = None,
**kwargs: object, **kwargs: object,
) -> Qwen2_5_VLProcessor: ) -> Qwen2_5_VLProcessor:
if fps is not None: if fps is not None:
...@@ -1116,8 +1116,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1116,8 +1116,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, Set, Tuple, TypedDict, Union from typing import Any, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -403,7 +403,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -403,7 +403,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -169,12 +170,12 @@ class Qwen2MoeAttention(nn.Module): ...@@ -169,12 +170,12 @@ class Qwen2MoeAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
dual_chunk_attention_config: Optional[Dict[str, Any]] = None, dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -389,8 +390,8 @@ class Qwen2MoeModel(nn.Module): ...@@ -389,8 +390,8 @@ class Qwen2MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -409,7 +410,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -409,7 +410,7 @@ class Qwen2MoeModel(nn.Module):
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
...@@ -532,8 +533,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -532,8 +533,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["rotary_emb.inv_freq"]), skip_prefixes=(["rotary_emb.inv_freq"]),
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
# Copyright 2024 The Qwen team. # Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -95,8 +96,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP, ...@@ -95,8 +96,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."]) ignore_unexpected_prefixes=["lm_head."])
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -25,8 +25,7 @@ ...@@ -25,8 +25,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, from typing import Any, Callable, Literal, Optional, TypedDict, Union
Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -102,7 +101,7 @@ class Qwen2VLImageEmbeddingInputs(TypedDict): ...@@ -102,7 +101,7 @@ class Qwen2VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
image_embeds: torch.Tensor image_embeds: torch.Tensor
"""Supported types: """Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features. - list[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features. Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features - `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors). (concatenation of all images' feature tensors).
...@@ -142,7 +141,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict): ...@@ -142,7 +141,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"] type: Literal["video_embeds"]
video_embeds: torch.Tensor video_embeds: torch.Tensor
"""Supported types: """Supported types:
- List[`torch.Tensor`]: A list of tensors holding all videos' features. - list[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features. Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features - `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors). (concatenation of all videos' feature tensors).
...@@ -662,8 +661,8 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -662,8 +661,8 @@ class Qwen2VisionTransformer(nn.Module):
return x return x
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -671,7 +670,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -671,7 +670,7 @@ class Qwen2VisionTransformer(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
...@@ -1394,8 +1393,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1394,8 +1393,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3 model compatible with HuggingFace weights.""" """Inference-only Qwen3 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -63,7 +64,7 @@ class Qwen3Attention(nn.Module): ...@@ -63,7 +64,7 @@ class Qwen3Attention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None, rope_scaling: Optional[tuple] = None,
prefix: str = "", prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None: attn_type: str = AttentionType.DECODER) -> None:
super().__init__() super().__init__()
...@@ -201,7 +202,7 @@ class Qwen3DecoderLayer(nn.Module): ...@@ -201,7 +202,7 @@ class Qwen3DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -309,8 +310,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -309,8 +310,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -149,7 +150,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -149,7 +150,7 @@ class Qwen3MoeAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
head_dim: Optional[int] = None, head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06, rms_norm_eps: float = 1e-06,
...@@ -373,8 +374,8 @@ class Qwen3MoeModel(nn.Module): ...@@ -373,8 +374,8 @@ class Qwen3MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -393,7 +394,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -393,7 +394,7 @@ class Qwen3MoeModel(nn.Module):
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
...@@ -527,8 +528,8 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): ...@@ -527,8 +528,8 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["rotary_emb.inv_freq"]), skip_prefixes=(["rotary_emb.inv_freq"]),
......
...@@ -9,10 +9,9 @@ import copy ...@@ -9,10 +9,9 @@ import copy
import math import math
import re import re
import unicodedata import unicodedata
from collections.abc import Collection, Mapping, Sequence from collections.abc import Collection, Mapping, Sequence, Set
from collections.abc import Set as AbstractSet
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Callable, List, Literal, Optional, TypedDict, Union from typing import Callable, Literal, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -395,7 +394,7 @@ def _get_tokenizer_without_image_pad( ...@@ -395,7 +394,7 @@ def _get_tokenizer_without_image_pad(
def tokenize( def tokenize(
self, self,
text: str, text: str,
allowed_special: Union[AbstractSet[str], str] = "all", allowed_special: Union[Set[str], str] = "all",
disallowed_special: Union[Collection[str], str] = (), disallowed_special: Union[Collection[str], str] = (),
**kwargs, **kwargs,
) -> list[Union[bytes, str]]: ) -> list[Union[bytes, str]]:
...@@ -411,7 +410,7 @@ def _get_tokenizer_without_image_pad( ...@@ -411,7 +410,7 @@ def _get_tokenizer_without_image_pad(
def _decode( def _decode(
self, self,
token_ids: Union[int, List[int]], token_ids: Union[int, list[int]],
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
errors: Optional[str] = None, errors: Optional[str] = None,
**kwargs, **kwargs,
......
...@@ -10,10 +10,10 @@ import subprocess ...@@ -10,10 +10,10 @@ import subprocess
import sys import sys
import tempfile import tempfile
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type, from typing import Callable, Optional, TypeVar, Union
TypeVar, Union)
import cloudpickle import cloudpickle
import torch.nn as nn import torch.nn as nn
...@@ -266,7 +266,7 @@ class _ModelInfo: ...@@ -266,7 +266,7 @@ class _ModelInfo:
supports_v0_only: bool supports_v0_only: bool
@staticmethod @staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
return _ModelInfo( return _ModelInfo(
architecture=model.__name__, architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model), is_text_generation_model=is_text_generation_model(model),
...@@ -290,7 +290,7 @@ class _BaseRegisteredModel(ABC): ...@@ -290,7 +290,7 @@ class _BaseRegisteredModel(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load_model_cls(self) -> Type[nn.Module]: def load_model_cls(self) -> type[nn.Module]:
raise NotImplementedError raise NotImplementedError
...@@ -301,10 +301,10 @@ class _RegisteredModel(_BaseRegisteredModel): ...@@ -301,10 +301,10 @@ class _RegisteredModel(_BaseRegisteredModel):
""" """
interfaces: _ModelInfo interfaces: _ModelInfo
model_cls: Type[nn.Module] model_cls: type[nn.Module]
@staticmethod @staticmethod
def from_model_cls(model_cls: Type[nn.Module]): def from_model_cls(model_cls: type[nn.Module]):
return _RegisteredModel( return _RegisteredModel(
interfaces=_ModelInfo.from_model_cls(model_cls), interfaces=_ModelInfo.from_model_cls(model_cls),
model_cls=model_cls, model_cls=model_cls,
...@@ -313,7 +313,7 @@ class _RegisteredModel(_BaseRegisteredModel): ...@@ -313,7 +313,7 @@ class _RegisteredModel(_BaseRegisteredModel):
def inspect_model_cls(self) -> _ModelInfo: def inspect_model_cls(self) -> _ModelInfo:
return self.interfaces return self.interfaces
def load_model_cls(self) -> Type[nn.Module]: def load_model_cls(self) -> type[nn.Module]:
return self.model_cls return self.model_cls
...@@ -330,7 +330,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel): ...@@ -330,7 +330,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
return _run_in_subprocess( return _run_in_subprocess(
lambda: _ModelInfo.from_model_cls(self.load_model_cls())) lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
def load_model_cls(self) -> Type[nn.Module]: def load_model_cls(self) -> type[nn.Module]:
mod = importlib.import_module(self.module_name) mod = importlib.import_module(self.module_name)
return getattr(mod, self.class_name) return getattr(mod, self.class_name)
...@@ -339,7 +339,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel): ...@@ -339,7 +339,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
def _try_load_model_cls( def _try_load_model_cls(
model_arch: str, model_arch: str,
model: _BaseRegisteredModel, model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]: ) -> Optional[type[nn.Module]]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.verify_model_arch(model_arch) current_platform.verify_model_arch(model_arch)
try: try:
...@@ -366,15 +366,15 @@ def _try_inspect_model_cls( ...@@ -366,15 +366,15 @@ def _try_inspect_model_cls(
@dataclass @dataclass
class _ModelRegistry: class _ModelRegistry:
# Keyed by model_arch # Keyed by model_arch
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
def get_supported_archs(self) -> AbstractSet[str]: def get_supported_archs(self) -> Set[str]:
return self.models.keys() return self.models.keys()
def register_model( def register_model(
self, self,
model_arch: str, model_arch: str,
model_cls: Union[Type[nn.Module], str], model_cls: Union[type[nn.Module], str],
) -> None: ) -> None:
""" """
Register an external model to be used in vLLM. Register an external model to be used in vLLM.
...@@ -413,7 +413,7 @@ class _ModelRegistry: ...@@ -413,7 +413,7 @@ class _ModelRegistry:
self.models[model_arch] = model self.models[model_arch] = model
def _raise_for_unsupported(self, architectures: List[str]): def _raise_for_unsupported(self, architectures: list[str]):
all_supported_archs = self.get_supported_archs() all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures): if any(arch in all_supported_archs for arch in architectures):
...@@ -426,7 +426,7 @@ class _ModelRegistry: ...@@ -426,7 +426,7 @@ class _ModelRegistry:
f"Supported architectures: {all_supported_archs}") f"Supported architectures: {all_supported_archs}")
def _try_load_model_cls(self, def _try_load_model_cls(self,
model_arch: str) -> Optional[Type[nn.Module]]: model_arch: str) -> Optional[type[nn.Module]]:
if model_arch not in self.models: if model_arch not in self.models:
return None return None
...@@ -440,8 +440,8 @@ class _ModelRegistry: ...@@ -440,8 +440,8 @@ class _ModelRegistry:
def _normalize_archs( def _normalize_archs(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> List[str]: ) -> list[str]:
if isinstance(architectures, str): if isinstance(architectures, str):
architectures = [architectures] architectures = [architectures]
if not architectures: if not architectures:
...@@ -458,8 +458,8 @@ class _ModelRegistry: ...@@ -458,8 +458,8 @@ class _ModelRegistry:
def inspect_model_cls( def inspect_model_cls(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> Tuple[_ModelInfo, str]: ) -> tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures) architectures = self._normalize_archs(architectures)
for arch in architectures: for arch in architectures:
...@@ -471,8 +471,8 @@ class _ModelRegistry: ...@@ -471,8 +471,8 @@ class _ModelRegistry:
def resolve_model_cls( def resolve_model_cls(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> Tuple[Type[nn.Module], str]: ) -> tuple[type[nn.Module], str]:
architectures = self._normalize_archs(architectures) architectures = self._normalize_archs(architectures)
for arch in architectures: for arch in architectures:
...@@ -484,77 +484,77 @@ class _ModelRegistry: ...@@ -484,77 +484,77 @@ class _ModelRegistry:
def is_text_generation_model( def is_text_generation_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model return model_cls.is_text_generation_model
def is_pooling_model( def is_pooling_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_pooling_model return model_cls.is_pooling_model
def is_cross_encoder_model( def is_cross_encoder_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_cross_encoding return model_cls.supports_cross_encoding
def is_multimodal_model( def is_multimodal_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal return model_cls.supports_multimodal
def is_pp_supported_model( def is_pp_supported_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_pp return model_cls.supports_pp
def model_has_inner_state( def model_has_inner_state(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_inner_state return model_cls.has_inner_state
def is_attention_free_model( def is_attention_free_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_attention_free return model_cls.is_attention_free
def is_hybrid_model( def is_hybrid_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_hybrid return model_cls.is_hybrid
def is_noops_model( def is_noops_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_noops return model_cls.has_noops
def is_transcription_model( def is_transcription_model(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription return model_cls.supports_transcription
def is_v1_compatible( def is_v1_compatible(
self, self,
architectures: Union[str, List[str]], architectures: Union[str, list[str]],
) -> bool: ) -> bool:
model_cls, _ = self.inspect_model_cls(architectures) model_cls, _ = self.inspect_model_cls(architectures)
return not model_cls.supports_v0_only return not model_cls.supports_v0_only
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from typing import Iterable, Optional, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -135,7 +136,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -135,7 +136,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
prefix=prefix, prefix=prefix,
embedding_class=RobertaEmbedding) embedding_class=RobertaEmbedding)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
# Separate weights in "roberta"-prefixed and all else (not in memory). # Separate weights in "roberta"-prefixed and all else (not in memory).
# For use with models like FacebookAI/roberta-base. # For use with models like FacebookAI/roberta-base.
...@@ -187,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -187,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.classifier = RobertaClassificationHead(config) self.classifier = RobertaClassificationHead(config)
self._pooler = CrossEncodingPooler(config, self.classifier) self._pooler = CrossEncodingPooler(config, self.classifier)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights) bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
...@@ -249,8 +250,8 @@ def create_position_ids_from_input_ids(input_ids, ...@@ -249,8 +250,8 @@ def create_position_ids_from_input_ids(input_ids,
def roberta_task_weights_filter( def roberta_task_weights_filter(
all_weights: Iterable[Tuple[str, torch.Tensor]] all_weights: Iterable[tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str, ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
torch.Tensor]]]: torch.Tensor]]]:
""" """
Separate task-specific weights that are applied on top Separate task-specific weights that are applied on top
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
within a vision language model.""" within a vision language model."""
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -265,7 +266,7 @@ class SiglipEncoderLayer(nn.Module): ...@@ -265,7 +266,7 @@ class SiglipEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, None]: ) -> tuple[torch.Tensor, None]:
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
...@@ -480,8 +481,8 @@ class SiglipVisionModel(nn.Module): ...@@ -480,8 +481,8 @@ class SiglipVisionModel(nn.Module):
feature_sample_layers=feature_sample_layers, feature_sample_layers=feature_sample_layers,
) )
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -489,7 +490,7 @@ class SiglipVisionModel(nn.Module): ...@@ -489,7 +490,7 @@ class SiglipVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# -------------------------------------------------------- # --------------------------------------------------------
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union from typing import Literal, Optional, TypedDict, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -937,8 +937,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -937,8 +937,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
skip_prefixes = [ skip_prefixes = [
"action_embed", "temporal_embed", "track_embed", "action_embed", "temporal_embed", "track_embed",
"track_embed_decoder", "box_token", "cg_criterion", "cg_model", "track_embed_decoder", "box_token", "cg_criterion", "cg_model",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment