Unverified commit 26d04193, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -200,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -352,7 +353,7 @@ class DeepseekV2MLAAttention(nn.Module):
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -736,8 +737,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
......@@ -753,7 +754,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -4,7 +4,7 @@
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -45,7 +45,7 @@ _IMAGE_TOKEN = "<image>"
class DeepseekVL2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
"""
......@@ -57,7 +57,7 @@ class DeepseekVL2ImagePixelInputs(TypedDict):
class DeepseekVL2VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
......@@ -394,8 +394,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return model
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.vision_config.image_size
expected_dims = (3, h, w)
......@@ -415,8 +415,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return data
def _validate_images_spatial_crop(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
expected_dims = 2
def _validate_shape(d: torch.Tensor):
......@@ -640,8 +640,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
......@@ -183,7 +184,7 @@ class EAGLE(nn.Module):
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# This implementation is incompatible with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
......
......@@ -24,7 +24,8 @@
# limitations under the License.
"""Inference-only Exaone model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -102,7 +103,7 @@ class ExaoneAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -196,7 +197,7 @@ class ExaoneBlockAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -282,7 +283,7 @@ class ExaoneDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -384,8 +385,8 @@ class ExaoneModel(nn.Module):
hidden_states, _ = self.ln_f(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -395,7 +396,7 @@ class ExaoneModel(nn.Module):
(".gate_up_proj", ".c_fc_1", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -535,8 +536,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# With tie_word_embeddings, we can skip lm_head.weight
......
......@@ -16,7 +16,7 @@
# limitations under the License.
"""Llama model for fairseq2 weights."""
from typing import Iterable, Set, Tuple
from collections.abc import Iterable
import torch
from torch.nn import Parameter
......@@ -44,8 +44,8 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
f"model.{self.tp_rank}.pt",
]
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
# { "model_key": my_model_name, "my_model_name": state_dict }
# which we first need to unpack
......@@ -102,7 +102,7 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
name: str,
loaded_weight: torch.Tensor,
params: dict[str, Parameter],
) -> Tuple[str, torch.Tensor]:
) -> tuple[str, torch.Tensor]:
"""Reshape fairseq2's weights."""
def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
......
......@@ -20,7 +20,8 @@
"""PyTorch Falcon model."""
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -394,8 +395,8 @@ class FalconModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
total_num_heads = self.config.num_attention_heads
if self.config.new_decoder_architecture:
total_num_kv_heads = self.config.num_kv_heads
......@@ -405,7 +406,7 @@ class FalconModel(nn.Module):
total_num_kv_heads = total_num_heads
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
......@@ -498,8 +499,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -3,7 +3,7 @@
import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -713,8 +713,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -723,7 +723,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
......@@ -922,8 +922,8 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
'Florence2 only supports COSINE as temporal embedding.')
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
size = self.processor_config["size"]
h, w = size["height"], size["width"]
......@@ -944,12 +944,12 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
return data
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
pixel_values: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"pixel_values", None)
image_embeds: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
image_embeds: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"image_embeds", None)
......@@ -1096,7 +1096,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict
from typing import Literal, Optional, TypedDict
import torch
import torch.nn as nn
......@@ -382,7 +382,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model.lm_head, hidden_states, sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -15,8 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from collections.abc import Iterable
from functools import cache
from typing import Iterable, Optional, Set, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
......@@ -231,7 +232,7 @@ class GemmaDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -318,8 +319,8 @@ class GemmaModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -329,7 +330,7 @@ class GemmaModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
......@@ -413,8 +414,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -15,7 +15,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -218,7 +219,7 @@ class Gemma2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -305,8 +306,8 @@ class Gemma2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -316,7 +317,7 @@ class Gemma2Model(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
......@@ -413,8 +414,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -14,7 +14,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn.functional as F
......@@ -320,7 +321,7 @@ class Gemma3DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -412,8 +413,8 @@ class Gemma3Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -423,7 +424,7 @@ class Gemma3Model(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
......@@ -521,8 +522,8 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict
from typing import Any, Literal, Optional, TypedDict
import torch
from torch import nn
......@@ -701,8 +701,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
......@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -60,7 +61,7 @@ class Glm4Attention(nn.Module):
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
......@@ -183,7 +184,7 @@ class Glm4DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -293,8 +294,8 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -18,7 +18,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -280,10 +281,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if ".attn.bias" in name or ".attn.masked_bias" in name:
# Skip attention mask.
......
......@@ -19,7 +19,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -243,10 +244,10 @@ class GPTBigCodeModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if ".attn.bias" in name:
# Skip attention mask.
......@@ -327,8 +328,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]),
......
......@@ -17,7 +17,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -228,8 +229,8 @@ class GPTJModel(nn.Module):
hidden_states = self.ln_f(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -239,7 +240,7 @@ class GPTJModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name:
continue
......@@ -331,7 +332,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
sampling_metadata, self.lm_head.bias)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -17,7 +17,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -240,10 +241,10 @@ class GPTNeoXModel(nn.Module):
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if ("attention.bias" in name or "attention.masked_bias" in name
or "rotary_emb.inv_freq" in name):
......@@ -324,7 +325,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -97,7 +98,7 @@ class GraniteAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -230,7 +231,7 @@ class GraniteDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -321,8 +322,8 @@ class GraniteModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -332,7 +333,7 @@ class GraniteModel(nn.Module):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
......@@ -475,8 +476,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = [
"rotary_emb.inv_freq",
# Models trained using ColossalAI may include these tensors in
......
......@@ -23,7 +23,8 @@
# limitations under the License.
"""Inference-only IBM Granite speech model."""
import math
from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union
from collections.abc import Iterable, Mapping
from typing import Optional, TypedDict, Union
import torch
import torch.nn.functional as F
......@@ -763,8 +764,8 @@ class GraniteSpeechForConditionalGeneration(
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Set[str]:
weights: Iterable[tuple[str, torch.Tensor]],
) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -305,8 +306,8 @@ class GraniteMoeModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
new_weights = {}
for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'):
......@@ -425,8 +426,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment