Unverified commit 26d04193, authored by Harry Mellor and committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -200,7 +201,7 @@ class DeepseekV2Attention(nn.Module): ...@@ -200,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
q_lora_rank: int, q_lora_rank: int,
kv_lora_rank: int, kv_lora_rank: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -352,7 +353,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -352,7 +353,7 @@ class DeepseekV2MLAAttention(nn.Module):
q_lora_rank: Optional[int], q_lora_rank: Optional[int],
kv_lora_rank: int, kv_lora_rank: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -736,8 +737,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -736,8 +737,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
...@@ -753,7 +754,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -753,7 +754,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" """Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -45,7 +45,7 @@ _IMAGE_TOKEN = "<image>" ...@@ -45,7 +45,7 @@ _IMAGE_TOKEN = "<image>"
class DeepseekVL2ImagePixelInputs(TypedDict): class DeepseekVL2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_images, num_channels, height, width)` Shape: `(batch_size * num_images, num_channels, height, width)`
""" """
...@@ -57,7 +57,7 @@ class DeepseekVL2ImagePixelInputs(TypedDict): ...@@ -57,7 +57,7 @@ class DeepseekVL2ImagePixelInputs(TypedDict):
class DeepseekVL2VImageEmbeddingInputs(TypedDict): class DeepseekVL2VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)` """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
...@@ -394,8 +394,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -394,8 +394,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return model return model
def _validate_pixel_values( def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.vision_config.image_size h = w = self.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -415,8 +415,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -415,8 +415,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return data return data
def _validate_images_spatial_crop( def _validate_images_spatial_crop(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
expected_dims = 2 expected_dims = 2
def _validate_shape(d: torch.Tensor): def _validate_shape(d: torch.Tensor):
...@@ -640,8 +640,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -640,8 +640,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, autoloaded_weights = loader.load_weights(weights,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -183,7 +184,7 @@ class EAGLE(nn.Module): ...@@ -183,7 +184,7 @@ class EAGLE(nn.Module):
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B # This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a # due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights: # Llama model. Here's a compatible version with the same weights:
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights.""" """Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -102,7 +103,7 @@ class ExaoneAttention(nn.Module): ...@@ -102,7 +103,7 @@ class ExaoneAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -196,7 +197,7 @@ class ExaoneBlockAttention(nn.Module): ...@@ -196,7 +197,7 @@ class ExaoneBlockAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -282,7 +283,7 @@ class ExaoneDecoderLayer(nn.Module): ...@@ -282,7 +283,7 @@ class ExaoneDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -384,8 +385,8 @@ class ExaoneModel(nn.Module): ...@@ -384,8 +385,8 @@ class ExaoneModel(nn.Module):
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -395,7 +396,7 @@ class ExaoneModel(nn.Module): ...@@ -395,7 +396,7 @@ class ExaoneModel(nn.Module):
(".gate_up_proj", ".c_fc_1", 1), (".gate_up_proj", ".c_fc_1", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -535,8 +536,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -535,8 +536,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
# With tie_word_embeddings, we can skip lm_head.weight # With tie_word_embeddings, we can skip lm_head.weight
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# limitations under the License. # limitations under the License.
"""Llama model for fairseq2 weights.""" """Llama model for fairseq2 weights."""
from typing import Iterable, Set, Tuple from collections.abc import Iterable
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
...@@ -44,8 +44,8 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM): ...@@ -44,8 +44,8 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
f"model.{self.tp_rank}.pt", f"model.{self.tp_rank}.pt",
] ]
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
# fairseq2's serialization adds a wrapper to usual .pt state_dict's: # fairseq2's serialization adds a wrapper to usual .pt state_dict's:
# { "model_key": my_model_name, "my_model_name": state_dict } # { "model_key": my_model_name, "my_model_name": state_dict }
# which we first need to unpack # which we first need to unpack
...@@ -102,7 +102,7 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM): ...@@ -102,7 +102,7 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
name: str, name: str,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
params: dict[str, Parameter], params: dict[str, Parameter],
) -> Tuple[str, torch.Tensor]: ) -> tuple[str, torch.Tensor]:
"""Reshape fairseq2's weights.""" """Reshape fairseq2's weights."""
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor: def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
"""PyTorch Falcon model.""" """PyTorch Falcon model."""
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -394,8 +395,8 @@ class FalconModel(nn.Module): ...@@ -394,8 +395,8 @@ class FalconModel(nn.Module):
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
total_num_heads = self.config.num_attention_heads total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture: if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads total_num_kv_heads = self.config.num_kv_heads
...@@ -405,7 +406,7 @@ class FalconModel(nn.Module): ...@@ -405,7 +406,7 @@ class FalconModel(nn.Module):
total_num_kv_heads = total_num_heads total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
...@@ -498,8 +499,8 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -498,8 +499,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -713,8 +713,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): ...@@ -713,8 +713,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -723,7 +723,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only): ...@@ -723,7 +723,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -922,8 +922,8 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -922,8 +922,8 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
'Florence2 only supports COSINE as temporal embedding.') 'Florence2 only supports COSINE as temporal embedding.')
def _validate_pixel_values( def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
size = self.processor_config["size"] size = self.processor_config["size"]
h, w = size["height"], size["width"] h, w = size["height"], size["width"]
...@@ -944,12 +944,12 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -944,12 +944,12 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
return data return data
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[List[List[torch.Tensor]], pixel_values: Optional[Union[list[list[torch.Tensor]],
List[torch.Tensor], list[torch.Tensor],
torch.Tensor]] = kwargs.pop( torch.Tensor]] = kwargs.pop(
"pixel_values", None) "pixel_values", None)
image_embeds: Optional[Union[List[List[torch.Tensor]], image_embeds: Optional[Union[list[list[torch.Tensor]],
List[torch.Tensor], list[torch.Tensor],
torch.Tensor]] = kwargs.pop( torch.Tensor]] = kwargs.pop(
"image_embeds", None) "image_embeds", None)
...@@ -1096,7 +1096,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1096,7 +1096,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict from typing import Literal, Optional, TypedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -382,7 +382,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -382,7 +382,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.lm_head, hidden_states, sampling_metadata) self.language_model.lm_head, hidden_states, sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from collections.abc import Iterable
from functools import cache from functools import cache
from typing import Iterable, Optional, Set, Tuple, Union from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -231,7 +232,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -231,7 +232,7 @@ class GemmaDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -318,8 +319,8 @@ class GemmaModel(nn.Module): ...@@ -318,8 +319,8 @@ class GemmaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -329,7 +330,7 @@ class GemmaModel(nn.Module): ...@@ -329,7 +330,7 @@ class GemmaModel(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping: for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name: if shard_name not in name:
...@@ -413,8 +414,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -413,8 +414,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -218,7 +219,7 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -218,7 +219,7 @@ class Gemma2DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -305,8 +306,8 @@ class Gemma2Model(nn.Module): ...@@ -305,8 +306,8 @@ class Gemma2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -316,7 +317,7 @@ class Gemma2Model(nn.Module): ...@@ -316,7 +317,7 @@ class Gemma2Model(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
...@@ -413,8 +414,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -413,8 +414,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -320,7 +321,7 @@ class Gemma3DecoderLayer(nn.Module): ...@@ -320,7 +321,7 @@ class Gemma3DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -412,8 +413,8 @@ class Gemma3Model(nn.Module): ...@@ -412,8 +413,8 @@ class Gemma3Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -423,7 +424,7 @@ class Gemma3Model(nn.Module): ...@@ -423,7 +424,7 @@ class Gemma3Model(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
...@@ -521,8 +522,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -521,8 +522,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict from typing import Any, Literal, Optional, TypedDict
import torch import torch
from torch import nn from torch import nn
...@@ -701,8 +701,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -701,8 +701,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" """Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -60,7 +61,7 @@ class Glm4Attention(nn.Module): ...@@ -60,7 +61,7 @@ class Glm4Attention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None, rope_scaling: Optional[tuple] = None,
prefix: str = "", prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None: attn_type: str = AttentionType.DECODER) -> None:
super().__init__() super().__init__()
...@@ -183,7 +184,7 @@ class Glm4DecoderLayer(nn.Module): ...@@ -183,7 +184,7 @@ class Glm4DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -293,8 +294,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -293,8 +294,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -280,10 +281,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -280,10 +281,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if ".attn.bias" in name or ".attn.masked_bias" in name: if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask. # Skip attention mask.
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -243,10 +244,10 @@ class GPTBigCodeModel(nn.Module): ...@@ -243,10 +244,10 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if ".attn.bias" in name: if ".attn.bias" in name:
# Skip attention mask. # Skip attention mask.
...@@ -327,8 +328,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -327,8 +328,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."]), skip_prefixes=(["lm_head."]),
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -228,8 +229,8 @@ class GPTJModel(nn.Module): ...@@ -228,8 +229,8 @@ class GPTJModel(nn.Module):
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -239,7 +240,7 @@ class GPTJModel(nn.Module): ...@@ -239,7 +240,7 @@ class GPTJModel(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name: if "attn.bias" in name or "attn.masked_bias" in name:
continue continue
...@@ -331,7 +332,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -331,7 +332,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -240,10 +241,10 @@ class GPTNeoXModel(nn.Module): ...@@ -240,10 +241,10 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name): or "rotary_emb.inv_freq" in name):
...@@ -324,7 +325,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -324,7 +325,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights.""" """Inference-only IBM Granite model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -97,7 +98,7 @@ class GraniteAttention(nn.Module): ...@@ -97,7 +98,7 @@ class GraniteAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -230,7 +231,7 @@ class GraniteDecoderLayer(nn.Module): ...@@ -230,7 +231,7 @@ class GraniteDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -321,8 +322,8 @@ class GraniteModel(nn.Module): ...@@ -321,8 +322,8 @@ class GraniteModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -332,7 +333,7 @@ class GraniteModel(nn.Module): ...@@ -332,7 +333,7 @@ class GraniteModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
...@@ -475,8 +476,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -475,8 +476,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
skip_prefixes = [ skip_prefixes = [
"rotary_emb.inv_freq", "rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
......
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only IBM Granite speech model.""" """Inference-only IBM Granite speech model."""
import math import math
from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union from collections.abc import Iterable, Mapping
from typing import Optional, TypedDict, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -763,8 +764,8 @@ class GraniteSpeechForConditionalGeneration( ...@@ -763,8 +764,8 @@ class GraniteSpeechForConditionalGeneration(
def load_weights( def load_weights(
self, self,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> Set[str]: ) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GraniteMoe model.""" """Inference-only GraniteMoe model."""
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -305,8 +306,8 @@ class GraniteMoeModel(nn.Module): ...@@ -305,8 +306,8 @@ class GraniteMoeModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
new_weights = {} new_weights = {}
for n, p in weights: for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'): if n.endswith('.block_sparse_moe.input_linear.weight'):
...@@ -425,8 +426,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -425,8 +426,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment