Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -212,11 +213,11 @@ class ModernBertModel(nn.Module):
eps=config.norm_eps,
bias=config.norm_bias)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -280,7 +281,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self._pooler = CrossEncodingPooler(config, self.classifier,
ModernBertPooler(config))
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = []
......
......@@ -4,7 +4,7 @@
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from dataclasses import dataclass, field
from typing import List, Union
from typing import Union
@dataclass
......@@ -46,17 +46,17 @@ class ModelKeys:
@dataclass
class MultiModelKeys(ModelKeys):
language_model: List[str] = field(default_factory=list)
connector: List[str] = field(default_factory=list)
language_model: list[str] = field(default_factory=list)
connector: list[str] = field(default_factory=list)
# vision tower and audio tower
tower_model: List[str] = field(default_factory=list)
generator: List[str] = field(default_factory=list)
tower_model: list[str] = field(default_factory=list)
generator: list[str] = field(default_factory=list)
@staticmethod
def from_string_field(language_model: Union[str, List[str]] = None,
connector: Union[str, List[str]] = None,
tower_model: Union[str, List[str]] = None,
generator: Union[str, List[str]] = None,
def from_string_field(language_model: Union[str, list[str]] = None,
connector: Union[str, list[str]] = None,
tower_model: Union[str, list[str]] = None,
generator: Union[str, list[str]] = None,
**kwargs) -> 'MultiModelKeys':
def to_list(value):
......
......@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import List, Optional, Set, Tuple, TypedDict, Union
from typing import Optional, TypedDict, Union
import numpy as np
import torch
......@@ -90,7 +90,7 @@ class MolmoImageInputs(TypedDict):
@dataclass
class VisionBackboneConfig:
image_default_input_size: Tuple[int, int] = (336, 336)
image_default_input_size: tuple[int, int] = (336, 336)
image_patch_size: int = 14
image_pos_patch_size: int = 14
image_emb_dim: int = 1024
......@@ -267,7 +267,7 @@ class BlockCollection(nn.Module):
for _ in range(config.image_num_layers)
])
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
hidden_states = []
for r in self.resblocks:
x = r(x)
......@@ -334,7 +334,7 @@ class VisionTransformer(nn.Module):
def forward(self,
x: torch.Tensor,
patch_num: Optional[int] = None) -> List[torch.Tensor]:
patch_num: Optional[int] = None) -> list[torch.Tensor]:
"""
: param x: (batch_size, num_patch, n_pixels)
"""
......@@ -434,7 +434,7 @@ class MolmoAttention(nn.Module):
)
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
......@@ -570,7 +570,7 @@ class MolmoDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -596,7 +596,7 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention
residual = hidden_states
hidden_states = self.self_attn(
......@@ -740,15 +740,15 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
# image_features: (batch_size, num_image, num_patch, d_model)
return image_features
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("merged_linear", "gate_proj", 0),
("merged_linear", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
......@@ -855,10 +855,10 @@ class MolmoModel(nn.Module, SupportsQuant):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict:
......@@ -1530,7 +1530,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights)
......@@ -1548,8 +1548,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _get_weights_with_merged_embedding(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
embedding_weights = {}
for name, weight in weights:
if "wte.embedding" in name:
......
......@@ -42,9 +42,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from collections.abc import Sequence
from copy import deepcopy
from functools import cached_property
from typing import List, Optional, Sequence, Tuple, Union
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -222,7 +223,7 @@ class MoonVisionPatchEmbed(nn.Module):
self,
out_dim: int,
in_dim: int = 3,
patch_size: Union[int, Tuple[int, int]] = (14, 14),
patch_size: Union[int, tuple[int, int]] = (14, 14),
pos_emb_height: int = 14,
pos_emb_width: int = 14,
):
......@@ -526,7 +527,7 @@ def patch_merger(
x: torch.Tensor,
grid_hw: torch.Tensor,
merge_kernel_size: list[int, int] = (2, 2),
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
d_model = x.size(-1)
outputs = []
......
......@@ -2,7 +2,8 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -265,10 +266,10 @@ class MPTModel(nn.Module):
hidden_states = self.norm_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......@@ -323,7 +324,7 @@ class MPTForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -69,7 +70,7 @@ def _cast_if_autocast_enabled(*args):
class NemotronLayerNorm1P(nn.LayerNorm):
def __init__(self,
normalized_shape: Union[int, List[int], torch.Size],
normalized_shape: Union[int, list[int], torch.Size],
eps: float = 1e-5,
elementwise_affine: bool = True,
bias: bool = True,
......@@ -133,7 +134,7 @@ class NemotronAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -267,7 +268,7 @@ class NemotronDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -441,8 +442,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -450,7 +451,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".qkv_proj", ".v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Type, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -135,7 +136,7 @@ class DeciLMDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if self._is_no_op_attention:
......@@ -168,7 +169,7 @@ class DeciModel(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
layer_type: type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
):
super().__init__()
......@@ -260,8 +261,8 @@ class DeciModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -271,7 +272,7 @@ class DeciModel(nn.Module):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -428,8 +429,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -209,7 +210,7 @@ class OlmoDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Attention block.
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -338,8 +339,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -349,7 +350,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -23,8 +23,9 @@
# limitations under the License.
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from functools import partial
from typing import Iterable, Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
......@@ -135,7 +136,7 @@ class Olmo2Attention(nn.Module):
)
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
......@@ -365,7 +366,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......
......@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -102,7 +103,7 @@ class OlmoeAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -307,8 +308,8 @@ class OlmoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -327,7 +328,7 @@ class OlmoeModel(nn.Module):
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
......@@ -439,8 +440,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=["rotary_emb.inv_freq"],
......
......@@ -18,7 +18,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -312,8 +313,8 @@ class OPTModel(nn.Module):
intermediate_tensors,
inputs_embeds=inputs_embeds)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -321,7 +322,7 @@ class OPTModel(nn.Module):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
......@@ -400,8 +401,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head.weight"]
......
......@@ -5,7 +5,8 @@
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -72,7 +73,7 @@ class OrionAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -186,7 +187,7 @@ class OrionDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -259,8 +260,8 @@ class OrionModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -270,7 +271,7 @@ class OrionModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
......@@ -341,8 +342,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
......
......@@ -17,8 +17,8 @@
# limitations under the License.
""" PyTorch Ovis model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -211,7 +211,7 @@ class OvisImagePatchInputs(TypedDict):
`(batch_size * (num_patches + 1))`
"""
patches_per_image: List[int]
patches_per_image: list[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`.
......@@ -545,8 +545,8 @@ class Ovis(nn.Module, SupportsMultiModal):
logits = self.llm.compute_logits(hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -391,7 +391,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -260,10 +261,10 @@ class PersimmonModel(nn.Module):
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if is_pp_missing_parameter(name, self):
continue
......@@ -336,7 +337,7 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -36,7 +36,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -248,8 +249,8 @@ class PhiModel(nn.Module):
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -257,7 +258,7 @@ class PhiModel(nn.Module):
("qkv_proj", "v_proj", "v")
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
......@@ -348,7 +349,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata, self.lm_head.bias)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -230,8 +231,8 @@ class Phi3SmallSelfAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[tuple[torch.Tensor]]]:
qkv, _ = self.query_key_value(hidden_states)
qkv = qkv.view(qkv.shape[:-1] +
......@@ -352,10 +353,10 @@ class Phi3SmallModel(nn.Module):
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict:
continue
......@@ -454,8 +455,8 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
output_hidden_states = output_hidden_states
return output_hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head.weight"]
......
......@@ -16,7 +16,7 @@
# limitations under the License.
import re
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -94,7 +94,7 @@ def _init_img_processor(hf_config: PretrainedConfig,
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
......@@ -113,7 +113,7 @@ class Phi3VImagePixelInputs(TypedDict):
class Phi3VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
......@@ -571,8 +571,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
expected_dims = (3, h, w)
......@@ -707,8 +707,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights,
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Any, Literal, Optional, TypedDict, Union
import numpy as np
import torch
......@@ -392,7 +392,7 @@ class Phi4MMImageEncoder(nn.Module):
class Phi4MMImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
......@@ -417,7 +417,7 @@ class Phi4MMImagePixelInputs(TypedDict):
class Phi4MMImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
......@@ -426,7 +426,7 @@ class Phi4MMImageEmbeddingInputs(TypedDict):
class Phi4MMAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
......@@ -1031,7 +1031,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return audio_embeds
def _parse_and_validate_image_input(self,
**kwargs: object) -> Optional[Dict]:
**kwargs: object) -> Optional[dict]:
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
if input_image_embeds is None:
return None
......@@ -1238,7 +1238,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
weights = ((name, data) for name, data in weights
if "lora" not in name)
......
......@@ -6,7 +6,7 @@
#!/usr/bin/env python3
import abc
import math
from typing import List, Literal, Optional
from typing import Literal, Optional
import numpy as np
import torch
......@@ -746,7 +746,7 @@ class ConformerEncoder(TransformerEncoderBase):
attention_group_size = attenion_heads = Multi-Query Attention
"""
extra_multi_layer_output_idxs: List[int]
extra_multi_layer_output_idxs: list[int]
def __init__( # pylint: disable-all
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment