Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
......@@ -5,7 +5,7 @@
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import math
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn.functional as F
......@@ -1586,7 +1586,7 @@ class AttModule(nn.Module):
memory: Optional[Tensor] = None,
pos_emb: Optional[Tensor] = None,
att_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
"""AttModule forward
Args:
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only PhiMoE model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -505,8 +506,8 @@ class PhiMoEModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -521,7 +522,7 @@ class PhiMoEModel(nn.Module):
num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
......@@ -657,8 +658,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
......
......@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -438,18 +438,18 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_encoder")
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")
def is_patch_merger(weight: Tuple[str, torch.Tensor]):
def is_patch_merger(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("patch_merger")
def is_pre_mm_projector_norm(weight: Tuple[str, torch.Tensor]):
def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
return weight[0].startswith("pre_mm_projector_norm")
# Get references to parameters for direct loading
......@@ -566,7 +566,7 @@ def apply_rotary_emb_vit(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
assert freqs_cis.dtype == torch.complex64
......@@ -671,7 +671,7 @@ class Transformer(nn.Module):
return x
def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor:
def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
positions = torch.cat([
torch.stack(
torch.meshgrid(
......@@ -733,7 +733,7 @@ class VisionTransformer(nn.Module):
def forward(
self,
images: List[torch.Tensor],
images: list[torch.Tensor],
) -> torch.Tensor:
"""
Args:
......@@ -1023,7 +1023,7 @@ class PixtralHFAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
batch, patches, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
......@@ -1249,8 +1249,8 @@ class PixtralHFVisionModel(nn.Module):
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -1260,7 +1260,7 @@ class PixtralHFVisionModel(nn.Module):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
layer_count = len(self.transformer.layers)
for name, loaded_weight in weights:
......
# SPDX-License-Identifier: Apache-2.0
"""Inference-only PLaMo2 model."""
import math
from typing import Iterable, Optional, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -659,7 +660,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = (self.config.mamba_num_heads *
self.config.hidden_size_per_head)
......@@ -682,7 +683,7 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel, HasInnerState, IsHybrid,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
......
......@@ -16,7 +16,7 @@
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -154,7 +154,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
"by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):
......@@ -195,8 +195,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
) -> Optional[PoolerOutput]:
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_list = []
model_buffers = dict(self.named_buffers())
loaded_buffers = []
......
......@@ -6,7 +6,8 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
import json
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -76,7 +77,7 @@ class QWenAttention(nn.Module):
num_heads: int,
max_position_embeddings: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
......@@ -166,7 +167,7 @@ class QWenBlock(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -284,15 +285,15 @@ class QWenBaseModel(nn.Module):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Any, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -108,7 +109,7 @@ class Qwen2Attention(nn.Module):
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
......@@ -245,7 +246,7 @@ class Qwen2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -367,8 +368,8 @@ class Qwen2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -378,7 +379,7 @@ class Qwen2Model(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -490,8 +491,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......@@ -559,7 +560,7 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
......
......@@ -21,10 +21,10 @@
# limitations under the License.
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Sequence,
Set, Tuple, Union)
from typing import Any, Optional, Union
import torch
import torch.nn as nn
......@@ -138,7 +138,7 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, List[float]]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5OmniProcessor:
if fps is not None:
......@@ -550,7 +550,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_image_input(
self,
**kwargs: Dict[str, Any],
**kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
......@@ -589,7 +589,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_video_input(
self,
**kwargs: Dict[str, Any],
**kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLVideoInputs]:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
......@@ -627,7 +627,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _process_audio_input(
self,
audio_input: Qwen2AudioInputs,
audio_hashes: List[str] = None,
audio_hashes: list[str] = None,
cached_audio_features: torch.Tensor = None,
) -> torch.Tensor:
......@@ -676,7 +676,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _process_video_input(
self,
video_input: Qwen2_5_VLVideoInputs,
video_hashes: List[str] = None,
video_hashes: list[str] = None,
cached_video_embeds: torch.Tensor = None) -> torch.Tensor:
if video_input["type"] == "video_embeds":
return video_input["video_embeds"].type(self.visual.dtype)
......@@ -825,7 +825,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
if audio_input is None and image_input is None and video_input is None:
return None
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
multimodal_embeddings: list[tuple[NestedTensors, str]] = []
if audio_input is not None:
audio_embeds = self._process_audio_input(audio_input)
......@@ -891,8 +891,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["talker.", "token2wav."],
......
......@@ -24,9 +24,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping
from functools import partial
from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import Callable, Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -91,7 +91,7 @@ class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
- list[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
......@@ -137,7 +137,7 @@ class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
video_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all videos' features.
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors).
......@@ -709,8 +709,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("attn.qkv.", "attn.q.", "q"),
......@@ -718,7 +718,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
("attn.qkv.", "attn.v.", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
......@@ -750,7 +750,7 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, List[float]]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5_VLProcessor:
if fps is not None:
......@@ -1116,8 +1116,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......
......@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -403,7 +403,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -23,7 +23,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
......@@ -169,12 +170,12 @@ class Qwen2MoeAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -389,8 +390,8 @@ class Qwen2MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -409,7 +410,7 @@ class Qwen2MoeModel(nn.Module):
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
......@@ -532,8 +533,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
......
......@@ -5,7 +5,8 @@
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -95,8 +96,8 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["lm_head."])
return loader.load_weights(weights)
......
......@@ -25,8 +25,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
from typing import Any, Callable, Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -102,7 +101,7 @@ class Qwen2VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
- list[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
......@@ -142,7 +141,7 @@ class Qwen2VLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
video_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all videos' features.
- list[`torch.Tensor`]: A list of tensors holding all videos' features.
Each tensor holds an video's features.
- `torch.Tensor`: A tensor holding all videos' features
(concatenation of all videos' feature tensors).
......@@ -662,8 +661,8 @@ class Qwen2VisionTransformer(nn.Module):
return x
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -671,7 +670,7 @@ class Qwen2VisionTransformer(nn.Module):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
......@@ -1394,8 +1393,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......
......@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -63,7 +64,7 @@ class Qwen3Attention(nn.Module):
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
......@@ -201,7 +202,7 @@ class Qwen3DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -309,8 +310,8 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -149,7 +150,7 @@ class Qwen3MoeAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
......@@ -373,8 +374,8 @@ class Qwen3MoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -393,7 +394,7 @@ class Qwen3MoeModel(nn.Module):
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
......@@ -527,8 +528,8 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["rotary_emb.inv_freq"]),
......
......@@ -9,10 +9,9 @@ import copy
import math
import re
import unicodedata
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Set as AbstractSet
from collections.abc import Collection, Mapping, Sequence, Set
from functools import lru_cache, partial
from typing import Callable, List, Literal, Optional, TypedDict, Union
from typing import Callable, Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -395,7 +394,7 @@ def _get_tokenizer_without_image_pad(
def tokenize(
self,
text: str,
allowed_special: Union[AbstractSet[str], str] = "all",
allowed_special: Union[Set[str], str] = "all",
disallowed_special: Union[Collection[str], str] = (),
**kwargs,
) -> list[Union[bytes, str]]:
......@@ -411,7 +410,7 @@ def _get_tokenizer_without_image_pad(
def _decode(
self,
token_ids: Union[int, List[int]],
token_ids: Union[int, list[int]],
skip_special_tokens: bool = False,
errors: Optional[str] = None,
**kwargs,
......
......@@ -10,10 +10,10 @@ import subprocess
import sys
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
from typing import Callable, Optional, TypeVar, Union
import cloudpickle
import torch.nn as nn
......@@ -266,7 +266,7 @@ class _ModelInfo:
supports_v0_only: bool
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
......@@ -290,7 +290,7 @@ class _BaseRegisteredModel(ABC):
raise NotImplementedError
@abstractmethod
def load_model_cls(self) -> Type[nn.Module]:
def load_model_cls(self) -> type[nn.Module]:
raise NotImplementedError
......@@ -301,10 +301,10 @@ class _RegisteredModel(_BaseRegisteredModel):
"""
interfaces: _ModelInfo
model_cls: Type[nn.Module]
model_cls: type[nn.Module]
@staticmethod
def from_model_cls(model_cls: Type[nn.Module]):
def from_model_cls(model_cls: type[nn.Module]):
return _RegisteredModel(
interfaces=_ModelInfo.from_model_cls(model_cls),
model_cls=model_cls,
......@@ -313,7 +313,7 @@ class _RegisteredModel(_BaseRegisteredModel):
def inspect_model_cls(self) -> _ModelInfo:
return self.interfaces
def load_model_cls(self) -> Type[nn.Module]:
def load_model_cls(self) -> type[nn.Module]:
return self.model_cls
......@@ -330,7 +330,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
return _run_in_subprocess(
lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
def load_model_cls(self) -> Type[nn.Module]:
def load_model_cls(self) -> type[nn.Module]:
mod = importlib.import_module(self.module_name)
return getattr(mod, self.class_name)
......@@ -339,7 +339,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
def _try_load_model_cls(
model_arch: str,
model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
) -> Optional[type[nn.Module]]:
from vllm.platforms import current_platform
current_platform.verify_model_arch(model_arch)
try:
......@@ -366,15 +366,15 @@ def _try_inspect_model_cls(
@dataclass
class _ModelRegistry:
# Keyed by model_arch
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
models: dict[str, _BaseRegisteredModel] = field(default_factory=dict)
def get_supported_archs(self) -> AbstractSet[str]:
def get_supported_archs(self) -> Set[str]:
return self.models.keys()
def register_model(
self,
model_arch: str,
model_cls: Union[Type[nn.Module], str],
model_cls: Union[type[nn.Module], str],
) -> None:
"""
Register an external model to be used in vLLM.
......@@ -413,7 +413,7 @@ class _ModelRegistry:
self.models[model_arch] = model
def _raise_for_unsupported(self, architectures: List[str]):
def _raise_for_unsupported(self, architectures: list[str]):
all_supported_archs = self.get_supported_archs()
if any(arch in all_supported_archs for arch in architectures):
......@@ -426,7 +426,7 @@ class _ModelRegistry:
f"Supported architectures: {all_supported_archs}")
def _try_load_model_cls(self,
model_arch: str) -> Optional[Type[nn.Module]]:
model_arch: str) -> Optional[type[nn.Module]]:
if model_arch not in self.models:
return None
......@@ -440,8 +440,8 @@ class _ModelRegistry:
def _normalize_archs(
self,
architectures: Union[str, List[str]],
) -> List[str]:
architectures: Union[str, list[str]],
) -> list[str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
......@@ -458,8 +458,8 @@ class _ModelRegistry:
def inspect_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[_ModelInfo, str]:
architectures: Union[str, list[str]],
) -> tuple[_ModelInfo, str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
......@@ -471,8 +471,8 @@ class _ModelRegistry:
def resolve_model_cls(
self,
architectures: Union[str, List[str]],
) -> Tuple[Type[nn.Module], str]:
architectures: Union[str, list[str]],
) -> tuple[type[nn.Module], str]:
architectures = self._normalize_archs(architectures)
for arch in architectures:
......@@ -484,77 +484,77 @@ class _ModelRegistry:
def is_text_generation_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_text_generation_model
def is_pooling_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_pooling_model
def is_cross_encoder_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_cross_encoding
def is_multimodal_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_multimodal
def is_pp_supported_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_pp
def model_has_inner_state(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_inner_state
def is_attention_free_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_attention_free
def is_hybrid_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_hybrid
def is_noops_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_noops
def is_transcription_model(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription
def is_v1_compatible(
self,
architectures: Union[str, List[str]],
architectures: Union[str, list[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return not model_cls.supports_v0_only
......
# SPDX-License-Identifier: Apache-2.0
import itertools
from typing import Iterable, Optional, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -135,7 +136,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
prefix=prefix,
embedding_class=RobertaEmbedding)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
# Separate weights in "roberta"-prefixed and all else (not in memory).
# For use with models like FacebookAI/roberta-base.
......@@ -187,7 +188,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.classifier = RobertaClassificationHead(config)
self._pooler = CrossEncodingPooler(config, self.classifier)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
......@@ -249,8 +250,8 @@ def create_position_ids_from_input_ids(input_ids,
def roberta_task_weights_filter(
all_weights: Iterable[Tuple[str, torch.Tensor]]
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[str,
all_weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
torch.Tensor]]]:
"""
Separate task-specific weights that are applied on top
......
......@@ -3,7 +3,8 @@
within a vision language model."""
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -265,7 +266,7 @@ class SiglipEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
) -> tuple[torch.Tensor, None]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
......@@ -480,8 +481,8 @@ class SiglipVisionModel(nn.Module):
feature_sample_layers=feature_sample_layers,
)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -489,7 +490,7 @@ class SiglipVisionModel(nn.Module):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
......
......@@ -8,7 +8,7 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
from typing import Literal, Optional, TypedDict, TypeVar, Union
import torch
import torch.nn as nn
......@@ -937,8 +937,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = [
"action_embed", "temporal_embed", "track_embed",
"track_embed_decoder", "box_token", "cg_criterion", "cg_model",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment