Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -212,11 +213,11 @@ class ModernBertModel(nn.Module): ...@@ -212,11 +213,11 @@ class ModernBertModel(nn.Module):
eps=config.norm_eps, eps=config.norm_eps,
bias=config.norm_bias) bias=config.norm_bias)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -280,7 +281,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -280,7 +281,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self._pooler = CrossEncodingPooler(config, self.classifier, self._pooler = CrossEncodingPooler(config, self.classifier,
ModernBertPooler(config)) ModernBertPooler(config))
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = [] self_weights = []
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py # https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Union from typing import Union
@dataclass @dataclass
...@@ -46,17 +46,17 @@ class ModelKeys: ...@@ -46,17 +46,17 @@ class ModelKeys:
@dataclass @dataclass
class MultiModelKeys(ModelKeys): class MultiModelKeys(ModelKeys):
language_model: List[str] = field(default_factory=list) language_model: list[str] = field(default_factory=list)
connector: List[str] = field(default_factory=list) connector: list[str] = field(default_factory=list)
# vision tower and audio tower # vision tower and audio tower
tower_model: List[str] = field(default_factory=list) tower_model: list[str] = field(default_factory=list)
generator: List[str] = field(default_factory=list) generator: list[str] = field(default_factory=list)
@staticmethod @staticmethod
def from_string_field(language_model: Union[str, List[str]] = None, def from_string_field(language_model: Union[str, list[str]] = None,
connector: Union[str, List[str]] = None, connector: Union[str, list[str]] = None,
tower_model: Union[str, List[str]] = None, tower_model: Union[str, list[str]] = None,
generator: Union[str, List[str]] = None, generator: Union[str, list[str]] = None,
**kwargs) -> 'MultiModelKeys': **kwargs) -> 'MultiModelKeys':
def to_list(value): def to_list(value):
......
...@@ -4,7 +4,7 @@ import math ...@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property, partial from functools import cached_property, partial
from typing import List, Optional, Set, Tuple, TypedDict, Union from typing import Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
...@@ -90,7 +90,7 @@ class MolmoImageInputs(TypedDict): ...@@ -90,7 +90,7 @@ class MolmoImageInputs(TypedDict):
@dataclass @dataclass
class VisionBackboneConfig: class VisionBackboneConfig:
image_default_input_size: Tuple[int, int] = (336, 336) image_default_input_size: tuple[int, int] = (336, 336)
image_patch_size: int = 14 image_patch_size: int = 14
image_pos_patch_size: int = 14 image_pos_patch_size: int = 14
image_emb_dim: int = 1024 image_emb_dim: int = 1024
...@@ -267,7 +267,7 @@ class BlockCollection(nn.Module): ...@@ -267,7 +267,7 @@ class BlockCollection(nn.Module):
for _ in range(config.image_num_layers) for _ in range(config.image_num_layers)
]) ])
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
hidden_states = [] hidden_states = []
for r in self.resblocks: for r in self.resblocks:
x = r(x) x = r(x)
...@@ -334,7 +334,7 @@ class VisionTransformer(nn.Module): ...@@ -334,7 +334,7 @@ class VisionTransformer(nn.Module):
def forward(self, def forward(self,
x: torch.Tensor, x: torch.Tensor,
patch_num: Optional[int] = None) -> List[torch.Tensor]: patch_num: Optional[int] = None) -> list[torch.Tensor]:
""" """
: param x: (batch_size, num_patch, n_pixels) : param x: (batch_size, num_patch, n_pixels)
""" """
...@@ -434,7 +434,7 @@ class MolmoAttention(nn.Module): ...@@ -434,7 +434,7 @@ class MolmoAttention(nn.Module):
) )
def _apply_qk_norm(self, q: torch.Tensor, def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1: if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous()) q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous())
...@@ -570,7 +570,7 @@ class MolmoDecoderLayer(nn.Module): ...@@ -570,7 +570,7 @@ class MolmoDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -596,7 +596,7 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): ...@@ -596,7 +596,7 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn( hidden_states = self.self_attn(
...@@ -740,15 +740,15 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant): ...@@ -740,15 +740,15 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
# image_features: (batch_size, num_image, num_patch, d_model) # image_features: (batch_size, num_image, num_patch, d_model)
return image_features return image_features
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("merged_linear", "gate_proj", 0), ("merged_linear", "gate_proj", 0),
("merged_linear", "up_proj", 1), ("merged_linear", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
...@@ -855,10 +855,10 @@ class MolmoModel(nn.Module, SupportsQuant): ...@@ -855,10 +855,10 @@ class MolmoModel(nn.Module, SupportsQuant):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
...@@ -1530,7 +1530,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1530,7 +1530,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights) weights = _get_weights_with_merged_embedding(weights)
...@@ -1548,8 +1548,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1548,8 +1548,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _get_weights_with_merged_embedding( def _get_weights_with_merged_embedding(
weights: Iterable[Tuple[str, torch.Tensor]] weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]: ) -> Iterable[tuple[str, torch.Tensor]]:
embedding_weights = {} embedding_weights = {}
for name, weight in weights: for name, weight in weights:
if "wte.embedding" in name: if "wte.embedding" in name:
......
...@@ -42,9 +42,10 @@ ...@@ -42,9 +42,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import math import math
from collections.abc import Sequence
from copy import deepcopy from copy import deepcopy
from functools import cached_property from functools import cached_property
from typing import List, Optional, Sequence, Tuple, Union from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -222,7 +223,7 @@ class MoonVisionPatchEmbed(nn.Module): ...@@ -222,7 +223,7 @@ class MoonVisionPatchEmbed(nn.Module):
self, self,
out_dim: int, out_dim: int,
in_dim: int = 3, in_dim: int = 3,
patch_size: Union[int, Tuple[int, int]] = (14, 14), patch_size: Union[int, tuple[int, int]] = (14, 14),
pos_emb_height: int = 14, pos_emb_height: int = 14,
pos_emb_width: int = 14, pos_emb_width: int = 14,
): ):
...@@ -526,7 +527,7 @@ def patch_merger( ...@@ -526,7 +527,7 @@ def patch_merger(
x: torch.Tensor, x: torch.Tensor,
grid_hw: torch.Tensor, grid_hw: torch.Tensor,
merge_kernel_size: list[int, int] = (2, 2), merge_kernel_size: list[int, int] = (2, 2),
) -> List[torch.Tensor]: ) -> list[torch.Tensor]:
d_model = x.size(-1) d_model = x.size(-1)
outputs = [] outputs = []
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -265,10 +266,10 @@ class MPTModel(nn.Module): ...@@ -265,10 +266,10 @@ class MPTModel(nn.Module):
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
...@@ -323,7 +324,7 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -323,7 +324,7 @@ class MPTForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Nemotron model compatible with HuggingFace weights.""" """Inference-only Nemotron model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -69,7 +70,7 @@ def _cast_if_autocast_enabled(*args): ...@@ -69,7 +70,7 @@ def _cast_if_autocast_enabled(*args):
class NemotronLayerNorm1P(nn.LayerNorm): class NemotronLayerNorm1P(nn.LayerNorm):
def __init__(self, def __init__(self,
normalized_shape: Union[int, List[int], torch.Size], normalized_shape: Union[int, list[int], torch.Size],
eps: float = 1e-5, eps: float = 1e-5,
elementwise_affine: bool = True, elementwise_affine: bool = True,
bias: bool = True, bias: bool = True,
...@@ -133,7 +134,7 @@ class NemotronAttention(nn.Module): ...@@ -133,7 +134,7 @@ class NemotronAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -267,7 +268,7 @@ class NemotronDecoderLayer(nn.Module): ...@@ -267,7 +268,7 @@ class NemotronDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -441,8 +442,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -441,8 +442,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -450,7 +451,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -450,7 +451,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".qkv_proj", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights.""" """Inference-only deci model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Type, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -135,7 +136,7 @@ class DeciLMDecoderLayer(nn.Module): ...@@ -135,7 +136,7 @@ class DeciLMDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if self._is_no_op_attention: if self._is_no_op_attention:
...@@ -168,7 +169,7 @@ class DeciModel(nn.Module): ...@@ -168,7 +169,7 @@ class DeciModel(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer, layer_type: type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
): ):
super().__init__() super().__init__()
...@@ -260,8 +261,8 @@ class DeciModel(nn.Module): ...@@ -260,8 +261,8 @@ class DeciModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -271,7 +272,7 @@ class DeciModel(nn.Module): ...@@ -271,7 +272,7 @@ class DeciModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -428,8 +429,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): ...@@ -428,8 +429,8 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -209,7 +210,7 @@ class OlmoDecoderLayer(nn.Module): ...@@ -209,7 +210,7 @@ class OlmoDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Attention block. # Attention block.
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -338,8 +339,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -338,8 +339,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -349,7 +350,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -349,7 +350,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -23,8 +23,9 @@ ...@@ -23,8 +23,9 @@
# limitations under the License. # limitations under the License.
"""Inference-only OLMo2 model compatible with HuggingFace weights.""" """Inference-only OLMo2 model compatible with HuggingFace weights."""
from collections.abc import Iterable
from functools import partial from functools import partial
from typing import Iterable, Optional, Tuple, Union from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -135,7 +136,7 @@ class Olmo2Attention(nn.Module): ...@@ -135,7 +136,7 @@ class Olmo2Attention(nn.Module):
) )
def _apply_qk_norm(self, q: torch.Tensor, def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1: if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous()) q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous())
...@@ -365,7 +366,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP): ...@@ -365,7 +366,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
......
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights.""" """Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -102,7 +103,7 @@ class OlmoeAttention(nn.Module): ...@@ -102,7 +103,7 @@ class OlmoeAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096, max_position_embeddings: int = 4096,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -307,8 +308,8 @@ class OlmoeModel(nn.Module): ...@@ -307,8 +308,8 @@ class OlmoeModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -327,7 +328,7 @@ class OlmoeModel(nn.Module): ...@@ -327,7 +328,7 @@ class OlmoeModel(nn.Module):
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
...@@ -439,8 +440,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ...@@ -439,8 +440,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=["rotary_emb.inv_freq"], skip_prefixes=["rotary_emb.inv_freq"],
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.""" """Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -312,8 +313,8 @@ class OPTModel(nn.Module): ...@@ -312,8 +313,8 @@ class OPTModel(nn.Module):
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -321,7 +322,7 @@ class OPTModel(nn.Module): ...@@ -321,7 +322,7 @@ class OPTModel(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -400,8 +401,8 @@ class OPTForCausalLM(nn.Module, SupportsPP): ...@@ -400,8 +401,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head.weight"] skip_prefixes=(["lm_head.weight"]
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
# Copyright (c) OrionStar Inc. # Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights.""" """Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -72,7 +73,7 @@ class OrionAttention(nn.Module): ...@@ -72,7 +73,7 @@ class OrionAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -186,7 +187,7 @@ class OrionDecoderLayer(nn.Module): ...@@ -186,7 +187,7 @@ class OrionDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -259,8 +260,8 @@ class OrionModel(nn.Module): ...@@ -259,8 +260,8 @@ class OrionModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -270,7 +271,7 @@ class OrionModel(nn.Module): ...@@ -270,7 +271,7 @@ class OrionModel(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -341,8 +342,8 @@ class OrionForCausalLM(nn.Module, SupportsPP): ...@@ -341,8 +342,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=([ skip_prefixes=([
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
# limitations under the License. # limitations under the License.
""" PyTorch Ovis model.""" """ PyTorch Ovis model."""
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from collections.abc import Iterable, Mapping
TypedDict, Union) from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -211,7 +211,7 @@ class OvisImagePatchInputs(TypedDict): ...@@ -211,7 +211,7 @@ class OvisImagePatchInputs(TypedDict):
`(batch_size * (num_patches + 1))` `(batch_size * (num_patches + 1))`
""" """
patches_per_image: List[int] patches_per_image: list[int]
""" """
List of number of total patches for each image in the batch. List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`. This is used to restore the first two dimensions of `flat_data`.
...@@ -545,8 +545,8 @@ class Ovis(nn.Module, SupportsMultiModal): ...@@ -545,8 +545,8 @@ class Ovis(nn.Module, SupportsMultiModal):
logits = self.llm.compute_logits(hidden_states, sampling_metadata) logits = self.llm.compute_logits(hidden_states, sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -391,7 +391,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -391,7 +391,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights.""" """Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -260,10 +261,10 @@ class PersimmonModel(nn.Module): ...@@ -260,10 +261,10 @@ class PersimmonModel(nn.Module):
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
...@@ -336,7 +337,7 @@ class PersimmonForCausalLM(nn.Module, SupportsPP): ...@@ -336,7 +337,7 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -36,7 +36,8 @@ ...@@ -36,7 +36,8 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -248,8 +249,8 @@ class PhiModel(nn.Module): ...@@ -248,8 +249,8 @@ class PhiModel(nn.Module):
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -257,7 +258,7 @@ class PhiModel(nn.Module): ...@@ -257,7 +258,7 @@ class PhiModel(nn.Module):
("qkv_proj", "v_proj", "v") ("qkv_proj", "v_proj", "v")
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
...@@ -348,7 +349,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -348,7 +349,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -230,8 +231,8 @@ class Phi3SmallSelfAttention(nn.Module): ...@@ -230,8 +231,8 @@ class Phi3SmallSelfAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]: Optional[tuple[torch.Tensor]]]:
qkv, _ = self.query_key_value(hidden_states) qkv, _ = self.query_key_value(hidden_states)
qkv = qkv.view(qkv.shape[:-1] + qkv = qkv.view(qkv.shape[:-1] +
...@@ -352,10 +353,10 @@ class Phi3SmallModel(nn.Module): ...@@ -352,10 +353,10 @@ class Phi3SmallModel(nn.Module):
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
...@@ -454,8 +455,8 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): ...@@ -454,8 +455,8 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
output_hidden_states = output_hidden_states output_hidden_states = output_hidden_states
return output_hidden_states return output_hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head.weight"] skip_prefixes=(["lm_head.weight"]
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
import re import re
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -94,7 +94,7 @@ def _init_img_processor(hf_config: PretrainedConfig, ...@@ -94,7 +94,7 @@ def _init_img_processor(hf_config: PretrainedConfig,
class Phi3VImagePixelInputs(TypedDict): class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)` `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
...@@ -113,7 +113,7 @@ class Phi3VImagePixelInputs(TypedDict): ...@@ -113,7 +113,7 @@ class Phi3VImagePixelInputs(TypedDict):
class Phi3VImageEmbeddingInputs(TypedDict): class Phi3VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)` """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
...@@ -571,8 +571,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -571,8 +571,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return data return data
def _validate_pixel_values( def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -707,8 +707,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -707,8 +707,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, autoloaded_weights = loader.load_weights(weights,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union from typing import Any, Literal, Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
...@@ -392,7 +392,7 @@ class Phi4MMImageEncoder(nn.Module): ...@@ -392,7 +392,7 @@ class Phi4MMImageEncoder(nn.Module):
class Phi4MMImagePixelInputs(TypedDict): class Phi4MMImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)` `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
...@@ -417,7 +417,7 @@ class Phi4MMImagePixelInputs(TypedDict): ...@@ -417,7 +417,7 @@ class Phi4MMImagePixelInputs(TypedDict):
class Phi4MMImageEmbeddingInputs(TypedDict): class Phi4MMImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)` """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
...@@ -426,7 +426,7 @@ class Phi4MMImageEmbeddingInputs(TypedDict): ...@@ -426,7 +426,7 @@ class Phi4MMImageEmbeddingInputs(TypedDict):
class Phi4MMAudioFeatureInputs(TypedDict): class Phi4MMAudioFeatureInputs(TypedDict):
type: Literal["audio_features"] type: Literal["audio_features"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)""" """Shape: `(batch_size * num_audios, 80, M)"""
...@@ -1031,7 +1031,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1031,7 +1031,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return audio_embeds return audio_embeds
def _parse_and_validate_image_input(self, def _parse_and_validate_image_input(self,
**kwargs: object) -> Optional[Dict]: **kwargs: object) -> Optional[dict]:
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds") input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
if input_image_embeds is None: if input_image_embeds is None:
return None return None
...@@ -1238,7 +1238,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1238,7 +1238,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None: torch.Tensor]]) -> None:
weights = ((name, data) for name, data in weights weights = ((name, data) for name, data in weights
if "lora" not in name) if "lora" not in name)
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import abc import abc
import math import math
from typing import List, Literal, Optional from typing import Literal, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -746,7 +746,7 @@ class ConformerEncoder(TransformerEncoderBase): ...@@ -746,7 +746,7 @@ class ConformerEncoder(TransformerEncoderBase):
attention_group_size = attenion_heads = Multi-Query Attention attention_group_size = attenion_heads = Multi-Query Attention
""" """
extra_multi_layer_output_idxs: List[int] extra_multi_layer_output_idxs: list[int]
def __init__( # pylint: disable-all def __init__( # pylint: disable-all
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment