Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional
from typing import Optional
from transformers import SmolVLMProcessor
......@@ -21,7 +21,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor(
self,
*,
max_image_size: Optional[Dict[str, int]] = None,
max_image_size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> SmolVLMProcessor:
if max_image_size is not None:
......
......@@ -23,7 +23,8 @@
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -101,7 +102,7 @@ class SolarAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -236,7 +237,7 @@ class SolarDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -437,8 +438,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -448,7 +449,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -20,7 +20,8 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -180,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -252,8 +253,8 @@ class StableLMEpochModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -263,7 +264,7 @@ class StableLMEpochModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
......@@ -335,8 +336,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# Models trained using ColossalAI may include these tensors in
......
......@@ -19,7 +19,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -255,8 +256,8 @@ class Starcoder2Model(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -265,7 +266,7 @@ class Starcoder2Model(nn.Module):
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
......@@ -342,8 +343,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# Models trained using ColossalAI may include these tensors in
......
......@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, Set, Tuple
from collections.abc import Iterable
import torch
import torch.nn as nn
......@@ -50,14 +50,14 @@ class TeleChat2Model(LlamaModel):
layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
('gate_up_proj', 'gate_proj', 0),
('gate_up_proj', 'up_proj', 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
total_num_heads = self.config.n_head
head_dim = self.config.hidden_size // total_num_heads
for name, loaded_weight in weights:
......@@ -128,8 +128,8 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
layer_type: type[nn.Module] = LlamaDecoderLayer):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
......
......@@ -15,7 +15,8 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from typing import Iterable, Literal, Optional, Union
from collections.abc import Iterable
from typing import Literal, Optional, Union
import torch
from torch import nn
......
......@@ -3,7 +3,7 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -619,8 +619,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["audio_tower."])
......
# SPDX-License-Identifier: Apache-2.0
import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, Union, overload)
from typing import Callable, Literal, Optional, Protocol, Union, overload
import torch
import torch.nn as nn
......@@ -58,8 +58,8 @@ class WeightsMapper:
return key
def apply(
self, weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
return ((out_name, data) for name, data in weights
if (out_name := self._map_name(name)) is not None)
......@@ -84,8 +84,8 @@ class AutoWeightsLoader:
self,
module: nn.Module,
*,
skip_prefixes: Optional[List[str]] = None,
ignore_unexpected_prefixes: Optional[List[str]] = None,
skip_prefixes: Optional[list[str]] = None,
ignore_unexpected_prefixes: Optional[list[str]] = None,
) -> None:
super().__init__()
......@@ -95,8 +95,8 @@ class AutoWeightsLoader:
def _groupby_prefix(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
weights_by_parts = ((weight_name.split(".", 1), weight_data)
for weight_name, weight_data in weights)
......@@ -129,7 +129,7 @@ class AutoWeightsLoader:
self,
base_prefix: str,
param: nn.Parameter,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
for weight_name, weight_data in weights:
weight_qualname = self._get_qualname(base_prefix, weight_name)
......@@ -159,7 +159,7 @@ class AutoWeightsLoader:
yield weight_qualname
def _add_loadable_non_param_tensors(self, module: nn.Module,
child_params: Dict[str, torch.Tensor]):
child_params: dict[str, torch.Tensor]):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
......@@ -182,7 +182,7 @@ class AutoWeightsLoader:
self,
base_prefix: str,
module: nn.Module,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]:
if isinstance(module, PPMissingLayer):
return
......@@ -251,10 +251,10 @@ class AutoWeightsLoader:
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
*,
mapper: Optional[WeightsMapper] = None,
) -> Set[str]:
) -> set[str]:
if mapper is not None:
weights = mapper.apply(weights)
......@@ -292,13 +292,13 @@ def flatten_bn(x: torch.Tensor) -> torch.Tensor:
@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
...
@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
x: Union[list[torch.Tensor], torch.Tensor],
*,
concat: Literal[True],
) -> torch.Tensor:
......@@ -307,18 +307,18 @@ def flatten_bn(
@overload
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
x: Union[list[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[list[torch.Tensor], torch.Tensor]:
...
def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor],
x: Union[list[torch.Tensor], torch.Tensor],
*,
concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[list[torch.Tensor], torch.Tensor]:
"""
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
......@@ -442,7 +442,7 @@ def merge_multimodal_embeddings(
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors,
placeholder_token_id: Union[int, List[int]],
placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
......@@ -596,7 +596,7 @@ def make_layers(
num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]:
) -> tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
"""
......@@ -614,10 +614,10 @@ def make_layers(
# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
_model_to_pp_missing_layer_names: dict[int, list[str]] = {}
def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
"""Get the names of the missing layers in a pipeline parallel model."""
model_id = id(model)
if model_id in _model_to_pp_missing_layer_names:
......@@ -645,7 +645,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
for missing_layer_name in get_pp_missing_layer_names(model))
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
def make_empty_intermediate_tensors(
batch_size: int,
......@@ -684,7 +684,7 @@ def extract_layer_index(layer_name: str) -> int:
- "model.encoder.layers.0.sub.1" -> ValueError
"""
subnames = layer_name.split(".")
int_vals: List[int] = []
int_vals: list[int] = []
for subname in subnames:
try:
int_vals.append(int(subname))
......
......@@ -2,7 +2,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
from typing import Optional, TypedDict, Union
import torch
from torch import nn
......@@ -382,7 +382,7 @@ class WhisperEncoder(nn.Module):
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))
def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]):
hidden_states = []
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
......@@ -460,7 +460,7 @@ class WhisperModel(nn.Module):
def forward(
self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]],
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
) -> torch.Tensor:
......@@ -474,14 +474,14 @@ class WhisperModel(nn.Module):
def get_encoder_outputs(
self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]],
) -> Optional[torch.Tensor]:
if input_features is None:
return None
return self.encoder(input_features)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
......@@ -491,7 +491,7 @@ class WhisperModel(nn.Module):
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
......@@ -722,8 +722,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
# add fake zeros bias for k_proj to state_dict
......@@ -732,8 +732,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def _create_fake_bias_for_k_proj(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create full zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros.
......
......@@ -6,8 +6,9 @@ https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
architectures in a hybrid model optimized for efficient sequence modeling. The
model alternates between state space model layers and attention-based layers.
"""
from collections.abc import Iterable
from itertools import cycle
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
......@@ -54,7 +55,7 @@ class Zamba2LoRA(nn.Module):
self,
input_dim: int,
rank: int,
output_dim: Union[int, List[int]],
output_dim: Union[int, list[int]],
quant_config: Optional[QuantizationConfig] = None,
):
"""Initialize the attention layer.
......@@ -279,7 +280,7 @@ class Zamba2MLP(nn.Module):
self,
config: Zamba2Config,
bare_block_idx: int,
num_hybrid_layers: Dict[int, int],
num_hybrid_layers: dict[int, int],
quant_config: Optional[QuantizationConfig] = None,
) -> None:
"""Initialize the MLP layer.
......@@ -769,8 +770,8 @@ class Zamba2Model(nn.Module):
hidden_states = self.final_layernorm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -779,7 +780,7 @@ class Zamba2Model(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for chkpt_weight_name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in chkpt_weight_name:
......@@ -914,9 +915,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str,
def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
torch.Tensor],
**kwargs) -> Dict[str, torch.Tensor]:
**kwargs) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture.
Args:
......@@ -930,7 +931,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> Dict[str, torch.Tensor]:
self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture.
Args:
......@@ -941,7 +942,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Returns:
......@@ -1001,7 +1002,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment