Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional from typing import Optional
from transformers import SmolVLMProcessor from transformers import SmolVLMProcessor
...@@ -21,7 +21,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo): ...@@ -21,7 +21,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor( def get_hf_processor(
self, self,
*, *,
max_image_size: Optional[Dict[str, int]] = None, max_image_size: Optional[dict[str, int]] = None,
**kwargs: object, **kwargs: object,
) -> SmolVLMProcessor: ) -> SmolVLMProcessor:
if max_image_size is not None: if max_image_size is not None:
......
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights.""" """Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -101,7 +102,7 @@ class SolarAttention(nn.Module): ...@@ -101,7 +102,7 @@ class SolarAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -236,7 +237,7 @@ class SolarDecoderLayer(nn.Module): ...@@ -236,7 +237,7 @@ class SolarDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -437,8 +438,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -437,8 +438,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -448,7 +449,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -448,7 +449,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -180,7 +181,7 @@ class StablelmDecoderLayer(nn.Module): ...@@ -180,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -252,8 +253,8 @@ class StableLMEpochModel(nn.Module): ...@@ -252,8 +253,8 @@ class StableLMEpochModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -263,7 +264,7 @@ class StableLMEpochModel(nn.Module): ...@@ -263,7 +264,7 @@ class StableLMEpochModel(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -335,8 +336,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -335,8 +336,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
......
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -255,8 +256,8 @@ class Starcoder2Model(nn.Module): ...@@ -255,8 +256,8 @@ class Starcoder2Model(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -265,7 +266,7 @@ class Starcoder2Model(nn.Module): ...@@ -265,7 +266,7 @@ class Starcoder2Model(nn.Module):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -342,8 +343,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -342,8 +343,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, Set, Tuple from collections.abc import Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -50,14 +50,14 @@ class TeleChat2Model(LlamaModel): ...@@ -50,14 +50,14 @@ class TeleChat2Model(LlamaModel):
layer.mlp.gate_up_proj.bias = None layer.mlp.gate_up_proj.bias = None
layer.mlp.gate_up_proj.skip_bias_add = True layer.mlp.gate_up_proj.skip_bias_add = True
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
('gate_up_proj', 'gate_proj', 0), ('gate_up_proj', 'gate_proj', 0),
('gate_up_proj', 'up_proj', 1), ('gate_up_proj', 'up_proj', 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
total_num_heads = self.config.n_head total_num_heads = self.config.n_head
head_dim = self.config.hidden_size // total_num_heads head_dim = self.config.hidden_size // total_num_heads
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -128,8 +128,8 @@ class TeleChat2ForCausalLM(LlamaForCausalLM): ...@@ -128,8 +128,8 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
layer_type: type[nn.Module] = LlamaDecoderLayer): layer_type: type[nn.Module] = LlamaDecoderLayer):
return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
......
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
# limitations under the License. # limitations under the License.
"""Wrapper around `transformers` models""" """Wrapper around `transformers` models"""
import re import re
from typing import Iterable, Literal, Optional, Union from collections.abc import Iterable
from typing import Literal, Optional, Union
import torch import torch
from torch import nn from torch import nn
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -619,8 +619,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -619,8 +619,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, loader = AutoWeightsLoader(self,
ignore_unexpected_prefixes=["audio_tower."]) ignore_unexpected_prefixes=["audio_tower."])
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, from typing import Callable, Literal, Optional, Protocol, Union, overload
Protocol, Set, Tuple, Union, overload)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -58,8 +58,8 @@ class WeightsMapper: ...@@ -58,8 +58,8 @@ class WeightsMapper:
return key return key
def apply( def apply(
self, weights: Iterable[Tuple[str, torch.Tensor]] self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]: ) -> Iterable[tuple[str, torch.Tensor]]:
return ((out_name, data) for name, data in weights return ((out_name, data) for name, data in weights
if (out_name := self._map_name(name)) is not None) if (out_name := self._map_name(name)) is not None)
...@@ -84,8 +84,8 @@ class AutoWeightsLoader: ...@@ -84,8 +84,8 @@ class AutoWeightsLoader:
self, self,
module: nn.Module, module: nn.Module,
*, *,
skip_prefixes: Optional[List[str]] = None, skip_prefixes: Optional[list[str]] = None,
ignore_unexpected_prefixes: Optional[List[str]] = None, ignore_unexpected_prefixes: Optional[list[str]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -95,8 +95,8 @@ class AutoWeightsLoader: ...@@ -95,8 +95,8 @@ class AutoWeightsLoader:
def _groupby_prefix( def _groupby_prefix(
self, self,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]: ) -> Iterable[tuple[str, Iterable[tuple[str, torch.Tensor]]]]:
weights_by_parts = ((weight_name.split(".", 1), weight_data) weights_by_parts = ((weight_name.split(".", 1), weight_data)
for weight_name, weight_data in weights) for weight_name, weight_data in weights)
...@@ -129,7 +129,7 @@ class AutoWeightsLoader: ...@@ -129,7 +129,7 @@ class AutoWeightsLoader:
self, self,
base_prefix: str, base_prefix: str,
param: nn.Parameter, param: nn.Parameter,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]: ) -> Iterable[str]:
for weight_name, weight_data in weights: for weight_name, weight_data in weights:
weight_qualname = self._get_qualname(base_prefix, weight_name) weight_qualname = self._get_qualname(base_prefix, weight_name)
...@@ -159,7 +159,7 @@ class AutoWeightsLoader: ...@@ -159,7 +159,7 @@ class AutoWeightsLoader:
yield weight_qualname yield weight_qualname
def _add_loadable_non_param_tensors(self, module: nn.Module, def _add_loadable_non_param_tensors(self, module: nn.Module,
child_params: Dict[str, torch.Tensor]): child_params: dict[str, torch.Tensor]):
""" """
Add tensor names that are not in the model params that may be in the Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats. safetensors, e.g., batch normalization stats.
...@@ -182,7 +182,7 @@ class AutoWeightsLoader: ...@@ -182,7 +182,7 @@ class AutoWeightsLoader:
self, self,
base_prefix: str, base_prefix: str,
module: nn.Module, module: nn.Module,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[str]: ) -> Iterable[str]:
if isinstance(module, PPMissingLayer): if isinstance(module, PPMissingLayer):
return return
...@@ -251,10 +251,10 @@ class AutoWeightsLoader: ...@@ -251,10 +251,10 @@ class AutoWeightsLoader:
def load_weights( def load_weights(
self, self,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
*, *,
mapper: Optional[WeightsMapper] = None, mapper: Optional[WeightsMapper] = None,
) -> Set[str]: ) -> set[str]:
if mapper is not None: if mapper is not None:
weights = mapper.apply(weights) weights = mapper.apply(weights)
...@@ -292,13 +292,13 @@ def flatten_bn(x: torch.Tensor) -> torch.Tensor: ...@@ -292,13 +292,13 @@ def flatten_bn(x: torch.Tensor) -> torch.Tensor:
@overload @overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: def flatten_bn(x: list[torch.Tensor]) -> list[torch.Tensor]:
... ...
@overload @overload
def flatten_bn( def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor], x: Union[list[torch.Tensor], torch.Tensor],
*, *,
concat: Literal[True], concat: Literal[True],
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -307,18 +307,18 @@ def flatten_bn( ...@@ -307,18 +307,18 @@ def flatten_bn(
@overload @overload
def flatten_bn( def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor], x: Union[list[torch.Tensor], torch.Tensor],
*, *,
concat: bool = False, concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
... ...
def flatten_bn( def flatten_bn(
x: Union[List[torch.Tensor], torch.Tensor], x: Union[list[torch.Tensor], torch.Tensor],
*, *,
concat: bool = False, concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]: ) -> Union[list[torch.Tensor], torch.Tensor]:
""" """
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
...@@ -442,7 +442,7 @@ def merge_multimodal_embeddings( ...@@ -442,7 +442,7 @@ def merge_multimodal_embeddings(
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
multimodal_embeddings: NestedTensors, multimodal_embeddings: NestedTensors,
placeholder_token_id: Union[int, List[int]], placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
...@@ -596,7 +596,7 @@ def make_layers( ...@@ -596,7 +596,7 @@ def make_layers(
num_hidden_layers: int, num_hidden_layers: int,
layer_fn: LayerFn, layer_fn: LayerFn,
prefix: str, prefix: str,
) -> Tuple[int, int, torch.nn.ModuleList]: ) -> tuple[int, int, torch.nn.ModuleList]:
"""Make a list of layers with the given layer function, taking """Make a list of layers with the given layer function, taking
pipeline parallelism into account. pipeline parallelism into account.
""" """
...@@ -614,10 +614,10 @@ def make_layers( ...@@ -614,10 +614,10 @@ def make_layers(
# NOTE: don't use lru_cache here because it can prevent garbage collection # NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {} _model_to_pp_missing_layer_names: dict[int, list[str]] = {}
def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
"""Get the names of the missing layers in a pipeline parallel model.""" """Get the names of the missing layers in a pipeline parallel model."""
model_id = id(model) model_id = id(model)
if model_id in _model_to_pp_missing_layer_names: if model_id in _model_to_pp_missing_layer_names:
...@@ -645,7 +645,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: ...@@ -645,7 +645,7 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
for missing_layer_name in get_pp_missing_layer_names(model)) for missing_layer_name in get_pp_missing_layer_names(model))
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors_factory(keys: list[str], hidden_size: int):
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
batch_size: int, batch_size: int,
...@@ -684,7 +684,7 @@ def extract_layer_index(layer_name: str) -> int: ...@@ -684,7 +684,7 @@ def extract_layer_index(layer_name: str) -> int:
- "model.encoder.layers.0.sub.1" -> ValueError - "model.encoder.layers.0.sub.1" -> ValueError
""" """
subnames = layer_name.split(".") subnames = layer_name.split(".")
int_vals: List[int] = [] int_vals: list[int] = []
for subname in subnames: for subname in subnames:
try: try:
int_vals.append(int(subname)) int_vals.append(int(subname))
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union from typing import Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -382,7 +382,7 @@ class WhisperEncoder(nn.Module): ...@@ -382,7 +382,7 @@ class WhisperEncoder(nn.Module):
self.embed_positions.weight.copy_( self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)) sinusoids(*self.embed_positions.weight.shape))
def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]): def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]):
hidden_states = [] hidden_states = []
for features in input_features: for features in input_features:
embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv1(features))
...@@ -460,7 +460,7 @@ class WhisperModel(nn.Module): ...@@ -460,7 +460,7 @@ class WhisperModel(nn.Module):
def forward( def forward(
self, self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]],
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -474,14 +474,14 @@ class WhisperModel(nn.Module): ...@@ -474,14 +474,14 @@ class WhisperModel(nn.Module):
def get_encoder_outputs( def get_encoder_outputs(
self, self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_features: Optional[Union[torch.Tensor, list[torch.Tensor]]],
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if input_features is None: if input_features is None:
return None return None
return self.encoder(input_features) return self.encoder(input_features)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"), (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
...@@ -491,7 +491,7 @@ class WhisperModel(nn.Module): ...@@ -491,7 +491,7 @@ class WhisperModel(nn.Module):
(".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"), (".encoder_attn.kv_proj", ".encoder_attn.v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -722,8 +722,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -722,8 +722,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
# add fake zeros bias for k_proj to state_dict # add fake zeros bias for k_proj to state_dict
...@@ -732,8 +732,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -732,8 +732,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def _create_fake_bias_for_k_proj( def _create_fake_bias_for_k_proj(
weights: Iterable[Tuple[str, torch.Tensor]] weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]: ) -> Iterable[tuple[str, torch.Tensor]]:
""" """
Create full zeros bias for k_proj weight in self-attn and x-attn layers. Create full zeros bias for k_proj weight in self-attn and x-attn layers.
So that the bias for k_proj in qkv_proj can be initialized with zeros. So that the bias for k_proj in qkv_proj can be initialized with zeros.
......
...@@ -6,8 +6,9 @@ https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer ...@@ -6,8 +6,9 @@ https://arxiv.org/abs/2411.15242, which combines Mamba and Transformer
architectures in a hybrid model optimized for efficient sequence modeling. The architectures in a hybrid model optimized for efficient sequence modeling. The
model alternates between state space model layers and attention-based layers. model alternates between state space model layers and attention-based layers.
""" """
from collections.abc import Iterable
from itertools import cycle from itertools import cycle
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -54,7 +55,7 @@ class Zamba2LoRA(nn.Module): ...@@ -54,7 +55,7 @@ class Zamba2LoRA(nn.Module):
self, self,
input_dim: int, input_dim: int,
rank: int, rank: int,
output_dim: Union[int, List[int]], output_dim: Union[int, list[int]],
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
"""Initialize the attention layer. """Initialize the attention layer.
...@@ -279,7 +280,7 @@ class Zamba2MLP(nn.Module): ...@@ -279,7 +280,7 @@ class Zamba2MLP(nn.Module):
self, self,
config: Zamba2Config, config: Zamba2Config,
bare_block_idx: int, bare_block_idx: int,
num_hybrid_layers: Dict[int, int], num_hybrid_layers: dict[int, int],
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
"""Initialize the MLP layer. """Initialize the MLP layer.
...@@ -769,8 +770,8 @@ class Zamba2Model(nn.Module): ...@@ -769,8 +770,8 @@ class Zamba2Model(nn.Module):
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -779,7 +780,7 @@ class Zamba2Model(nn.Module): ...@@ -779,7 +780,7 @@ class Zamba2Model(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for chkpt_weight_name, loaded_weight in weights: for chkpt_weight_name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in chkpt_weight_name: if weight_name not in chkpt_weight_name:
...@@ -914,9 +915,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -914,9 +915,9 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
return hidden_states return hidden_states
def copy_inputs_before_cuda_graphs(self, input_buffers: Dict[str, def copy_inputs_before_cuda_graphs(self, input_buffers: dict[str,
torch.Tensor], torch.Tensor],
**kwargs) -> Dict[str, torch.Tensor]: **kwargs) -> dict[str, torch.Tensor]:
"""Copy inputs before CUDA graph capture. """Copy inputs before CUDA graph capture.
Args: Args:
...@@ -930,7 +931,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -930,7 +931,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
input_buffers, **kwargs) input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs( def get_seqlen_agnostic_capture_inputs(
self, batch_size: int) -> Dict[str, torch.Tensor]: self, batch_size: int) -> dict[str, torch.Tensor]:
"""Get inputs for sequence-length-agnostic graph capture. """Get inputs for sequence-length-agnostic graph capture.
Args: Args:
...@@ -941,7 +942,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -941,7 +942,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches. """Calculate shapes for Mamba's convolutional and state caches.
Returns: Returns:
...@@ -1001,7 +1002,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only): ...@@ -1001,7 +1002,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment