Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
# SPDX-License-Identifier: Apache-2.0
"""Inference-only GraniteMoeHybrid model."""
# Added by the IBM Team, 2025
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -381,10 +382,10 @@ class GraniteMoeHybridModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
def _load(n, p):
param = params_dict[n]
......@@ -538,7 +539,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
......@@ -578,7 +579,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -4,7 +4,8 @@
The architecture is the same as granitemoe but with the addition of shared
experts.
"""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -208,8 +209,8 @@ class GraniteMoeSharedModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
new_weights = {}
for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'):
......@@ -329,8 +330,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Grok1 model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn.functional as F
......@@ -263,7 +264,7 @@ class Grok1DecoderLayer(nn.Module):
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -340,7 +341,7 @@ class Grok1Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
......@@ -371,8 +372,8 @@ class Grok1Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -390,7 +391,7 @@ class Grok1Model(nn.Module):
num_experts=num_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
......@@ -528,7 +529,7 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
......@@ -547,8 +548,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
skip_prefixes = ["rotary_emb.inv_freq"]
# Skip lm_head when tie_word_embeddings is True
if self.config.tie_word_embeddings:
......
......@@ -17,7 +17,8 @@
# limitations under the License.
"""PyTorch Idefics2 model."""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -342,8 +343,8 @@ class Idefics2VisionTransformer(nn.Module):
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -351,7 +352,7 @@ class Idefics2VisionTransformer(nn.Module):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)
for name, loaded_weight in weights:
......
......@@ -17,7 +17,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -85,7 +85,7 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
size: Optional[Dict[str, int]] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Idefics3Processor:
if size is not None:
......@@ -752,8 +752,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable)
import torch
from torch import Tensor
......@@ -102,7 +102,7 @@ class _SupportsMultiModalType(Protocol):
@overload
def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
...
......@@ -112,8 +112,8 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
def supports_multimodal(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type):
return isinstance(model, _SupportsMultiModalType)
......@@ -134,9 +134,9 @@ class SupportsLoRA(Protocol):
"""
# The `embedding_module` and `embedding_padding_modules`
# are empty by default.
embedding_modules: ClassVar[Dict[str, str]] = {}
embedding_padding_modules: ClassVar[List[str]] = []
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
embedding_modules: ClassVar[dict[str, str]] = {}
embedding_padding_modules: ClassVar[list[str]] = []
packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
# We can't use runtime_checkable with ClassVar for issubclass checks
......@@ -145,13 +145,13 @@ class SupportsLoRA(Protocol):
class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
packed_modules_mapping: dict[str, list[str]]
embedding_modules: dict[str, str]
embedding_padding_modules: list[str]
@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
...
......@@ -161,8 +161,8 @@ def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model)
if not result:
......@@ -191,7 +191,7 @@ def supports_lora(
return result
def _supports_lora(model: Union[Type[object], object]) -> bool:
def _supports_lora(model: Union[type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
......@@ -256,7 +256,7 @@ class _SupportsPPType(Protocol):
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
...
......@@ -266,8 +266,8 @@ def supports_pp(model: object) -> TypeIs[SupportsPP]:
def supports_pp(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
model: Union[type[object], object],
) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
supports_attributes = _supports_pp_attributes(model)
supports_inspect = _supports_pp_inspect(model)
......@@ -298,14 +298,14 @@ def supports_pp(
return supports_attributes and supports_inspect
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsPPType)
return isinstance(model, SupportsPP)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
......@@ -336,13 +336,13 @@ def has_inner_state(model: object) -> TypeIs[HasInnerState]:
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
...
def has_inner_state(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type):
return isinstance(model, _HasInnerStateType)
......@@ -373,13 +373,13 @@ def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
...
def is_attention_free(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
if isinstance(model, type):
return isinstance(model, _IsAttentionFreeType)
......@@ -410,13 +410,13 @@ def is_hybrid(model: object) -> TypeIs[IsHybrid]:
@overload
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
...
def is_hybrid(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
if isinstance(model, type):
return isinstance(model, _IsHybridType)
......@@ -439,13 +439,13 @@ def has_noops(model: object) -> TypeIs[HasNoOps]:
@overload
def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]:
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
...
def has_noops(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]:
model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type):
return isinstance(model, _HasNoOpsType)
......@@ -461,7 +461,7 @@ class SupportsCrossEncoding(Protocol):
@overload
def supports_cross_encoding(
model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
...
......@@ -471,8 +471,8 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
def _supports_cross_encoding(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
if isinstance(model, type):
return isinstance(model, SupportsCrossEncoding)
......@@ -481,15 +481,15 @@ def _supports_cross_encoding(
def supports_cross_encoding(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return is_pooling_model(model) and _supports_cross_encoding(model)
class SupportsQuant:
"""The interface required for all models that support quantization."""
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
quant_config: Optional[QuantizationConfig] = None
def __new__(cls, *args, **kwargs) -> Self:
......@@ -525,7 +525,7 @@ class SupportsTranscription(Protocol):
@overload
def supports_transcription(
model: Type[object]) -> TypeIs[Type[SupportsTranscription]]:
model: type[object]) -> TypeIs[type[SupportsTranscription]]:
...
......@@ -535,8 +535,8 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
def supports_transcription(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
if isinstance(model, type):
return isinstance(model, SupportsTranscription)
......@@ -551,7 +551,7 @@ class SupportsV0Only(Protocol):
@overload
def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
...
......@@ -561,8 +561,8 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
def supports_v0_only(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
if isinstance(model, type):
return isinstance(model, SupportsV0Only)
......
# SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload,
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
runtime_checkable)
import torch
......@@ -20,7 +20,7 @@ logger = init_logger(__name__)
# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor]
# which has T = list[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
......@@ -48,12 +48,12 @@ class VllmModel(Protocol[T_co]):
...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
model_init = model.__init__
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
......@@ -75,7 +75,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
...
......@@ -85,8 +85,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
def is_vllm_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
......@@ -105,7 +105,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
@overload
def is_text_generation_model(
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
...
......@@ -116,8 +116,8 @@ def is_text_generation_model(
def is_text_generation_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
TypeIs[VllmModelForTextGeneration]]:
if not is_vllm_model(model):
return False
......@@ -142,7 +142,7 @@ class VllmModelForPooling(VllmModel[T], Protocol[T]):
@overload
def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]:
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
...
......@@ -152,8 +152,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
def is_pooling_model(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
if not is_vllm_model(model):
return False
......
......@@ -6,8 +6,9 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable
from functools import partial
from typing import Iterable, Optional, Set, Tuple
from typing import Optional
import torch
import torch.nn as nn
......@@ -461,10 +462,10 @@ class InternVisionModel(nn.Module):
return encoder_outputs
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable
from functools import partial
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -81,7 +82,7 @@ class InternLM2Attention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -225,7 +226,7 @@ class InternLMDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -252,7 +253,7 @@ class InternLM2Model(nn.Module):
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer):
layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
......@@ -316,7 +317,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model):
model_type: type[InternLM2Model] = InternLM2Model):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -361,15 +362,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -407,7 +408,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
*,
vllm_config: VllmConfig,
prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model,
model_type: type[InternLM2Model] = InternLM2Model,
):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
......@@ -66,7 +66,7 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......
......@@ -8,7 +8,7 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
from typing import Literal, Optional, TypedDict, TypeVar, Union
import torch
import torch.nn as nn
......@@ -932,8 +932,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
# unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
skip_prefixes = [
"action_embed", "temporal_embed", "track_embed",
......
......@@ -21,7 +21,8 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -333,10 +334,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final
......
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model."""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -442,7 +443,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape = (
......@@ -464,8 +465,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -482,7 +483,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -583,7 +584,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision.
super().load_weights(weights)
......
......@@ -43,10 +43,9 @@
import copy
import math
from collections.abc import Mapping
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple,
TypedDict, Union)
from typing import Any, Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -120,7 +119,7 @@ class KimiVLMultiModalProjector(nn.Module):
class KimiVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, List[torch.Tensor]]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:`(num_patches, num_channels, patch_size, patch_size)`
"""
......@@ -447,7 +446,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata, **kwargs)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -103,7 +104,7 @@ class LlamaAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -285,7 +286,7 @@ class LlamaDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -394,8 +395,8 @@ class LlamaModel(nn.Module):
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -405,7 +406,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -582,8 +583,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......@@ -599,7 +600,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
......
......@@ -16,7 +16,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Any, Optional
import torch
from torch import nn
......@@ -48,7 +49,7 @@ class Llama4MoE(nn.Module):
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
# psuedo-standard is that the router scores are floats
router_scores = torch.sigmoid(router_scores.float())
......@@ -115,7 +116,7 @@ class Llama4Attention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -300,7 +301,7 @@ class Llama4DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -335,9 +336,9 @@ class Llama4Model(LlamaModel):
self,
name: str,
loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter],
loaded_params: Set[str],
expert_params_mapping: List[Tuple[str, str, int, str]],
params_dict: dict[str, nn.Parameter],
loaded_params: set[str],
expert_params_mapping: list[tuple[str, str, int, str]],
fused: bool = True,
) -> bool:
expert_param_loaded = False
......@@ -390,8 +391,8 @@ class Llama4Model(LlamaModel):
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -412,7 +413,7 @@ class Llama4Model(LlamaModel):
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True
......@@ -489,8 +490,8 @@ class Llama4ForCausalLM(LlamaForCausalLM):
prefix=prefix,
layer_type=layer_type)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......@@ -506,7 +507,7 @@ class Llama4ForCausalLM(LlamaForCausalLM):
self,
name: str,
loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]:
) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple
from collections.abc import Iterable
import torch
import torch.nn as nn
......@@ -92,8 +92,8 @@ class LlamaModel(nn.Module):
hidden_states = hidden_states + residual
return hidden_states, hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -103,7 +103,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
......@@ -150,7 +150,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
......@@ -56,7 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
embeds: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
embeds = self.input_layernorm(embeds)
......@@ -140,8 +141,8 @@ class LlamaModel(nn.Module):
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
return hidden_states, hidden_prenorm
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -151,7 +152,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if 'midlayer.' in name:
name = name.replace('midlayer.', 'layers.0.')
......@@ -228,7 +229,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
# combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
......
......@@ -2,8 +2,8 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union, cast)
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
Union, cast)
import torch
import torch.nn as nn
......@@ -751,8 +751,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
from collections.abc import Iterable, Mapping
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
Union)
import torch
import torch.nn as nn
......@@ -266,8 +267,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -450,7 +451,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self,
image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
......@@ -577,7 +578,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment