Commit 26d04193 (unverified), authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only GraniteMoeHybrid model.""" """Inference-only GraniteMoeHybrid model."""
# Added by the IBM Team, 2025 # Added by the IBM Team, 2025
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -381,10 +382,10 @@ class GraniteMoeHybridModel(nn.Module): ...@@ -381,10 +382,10 @@ class GraniteMoeHybridModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
def _load(n, p): def _load(n, p):
param = params_dict[n] param = params_dict[n]
...@@ -538,7 +539,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -538,7 +539,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
...@@ -578,7 +579,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, ...@@ -578,7 +579,7 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
The architecture is the same as granitemoe but with the addition of shared The architecture is the same as granitemoe but with the addition of shared
experts. experts.
""" """
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -208,8 +209,8 @@ class GraniteMoeSharedModel(nn.Module): ...@@ -208,8 +209,8 @@ class GraniteMoeSharedModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
new_weights = {} new_weights = {}
for n, p in weights: for n, p in weights:
if n.endswith('.block_sparse_moe.input_linear.weight'): if n.endswith('.block_sparse_moe.input_linear.weight'):
...@@ -329,8 +330,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -329,8 +330,8 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Grok1 model.""" """Inference-only Grok1 model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -263,7 +264,7 @@ class Grok1DecoderLayer(nn.Module): ...@@ -263,7 +264,7 @@ class Grok1DecoderLayer(nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -340,7 +341,7 @@ class Grok1Model(nn.Module): ...@@ -340,7 +341,7 @@ class Grok1Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -371,8 +372,8 @@ class Grok1Model(nn.Module): ...@@ -371,8 +372,8 @@ class Grok1Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -390,7 +391,7 @@ class Grok1Model(nn.Module): ...@@ -390,7 +391,7 @@ class Grok1Model(nn.Module):
num_experts=num_experts) num_experts=num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
...@@ -528,7 +529,7 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -528,7 +529,7 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
...@@ -547,8 +548,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -547,8 +548,8 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
skip_prefixes = ["rotary_emb.inv_freq"] skip_prefixes = ["rotary_emb.inv_freq"]
# Skip lm_head when tie_word_embeddings is True # Skip lm_head when tie_word_embeddings is True
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch Idefics2 model.""" """PyTorch Idefics2 model."""
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -342,8 +343,8 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -342,8 +343,8 @@ class Idefics2VisionTransformer(nn.Module):
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state return last_hidden_state
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -351,7 +352,7 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -351,7 +352,7 @@ class Idefics2VisionTransformer(nn.Module):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
layer_count = len(self.encoder.layers) layer_count = len(self.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -85,7 +85,7 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -85,7 +85,7 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_hf_processor( def get_hf_processor(
self, self,
*, *,
size: Optional[Dict[str, int]] = None, size: Optional[dict[str, int]] = None,
**kwargs: object, **kwargs: object,
) -> Idefics3Processor: ) -> Idefics3Processor:
if size is not None: if size is not None:
...@@ -752,8 +752,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -752,8 +752,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Protocol, Type, Union, overload, runtime_checkable) Union, overload, runtime_checkable)
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -102,7 +102,7 @@ class _SupportsMultiModalType(Protocol): ...@@ -102,7 +102,7 @@ class _SupportsMultiModalType(Protocol):
@overload @overload
def supports_multimodal( def supports_multimodal(
model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]: model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
... ...
...@@ -112,8 +112,8 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ...@@ -112,8 +112,8 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
def supports_multimodal( def supports_multimodal(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: ) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsMultiModalType) return isinstance(model, _SupportsMultiModalType)
...@@ -134,9 +134,9 @@ class SupportsLoRA(Protocol): ...@@ -134,9 +134,9 @@ class SupportsLoRA(Protocol):
""" """
# The `embedding_module` and `embedding_padding_modules` # The `embedding_module` and `embedding_padding_modules`
# are empty by default. # are empty by default.
embedding_modules: ClassVar[Dict[str, str]] = {} embedding_modules: ClassVar[dict[str, str]] = {}
embedding_padding_modules: ClassVar[List[str]] = [] embedding_padding_modules: ClassVar[list[str]] = []
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
# We can't use runtime_checkable with ClassVar for issubclass checks # We can't use runtime_checkable with ClassVar for issubclass checks
...@@ -145,13 +145,13 @@ class SupportsLoRA(Protocol): ...@@ -145,13 +145,13 @@ class SupportsLoRA(Protocol):
class _SupportsLoRAType(Protocol): class _SupportsLoRAType(Protocol):
supports_lora: Literal[True] supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]] packed_modules_mapping: dict[str, list[str]]
embedding_modules: Dict[str, str] embedding_modules: dict[str, str]
embedding_padding_modules: List[str] embedding_padding_modules: list[str]
@overload @overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
... ...
...@@ -161,8 +161,8 @@ def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ...@@ -161,8 +161,8 @@ def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
def supports_lora( def supports_lora(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: ) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
result = _supports_lora(model) result = _supports_lora(model)
if not result: if not result:
...@@ -191,7 +191,7 @@ def supports_lora( ...@@ -191,7 +191,7 @@ def supports_lora(
return result return result
def _supports_lora(model: Union[Type[object], object]) -> bool: def _supports_lora(model: Union[type[object], object]) -> bool:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsLoRAType) return isinstance(model, _SupportsLoRAType)
...@@ -256,7 +256,7 @@ class _SupportsPPType(Protocol): ...@@ -256,7 +256,7 @@ class _SupportsPPType(Protocol):
@overload @overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
... ...
...@@ -266,8 +266,8 @@ def supports_pp(model: object) -> TypeIs[SupportsPP]: ...@@ -266,8 +266,8 @@ def supports_pp(model: object) -> TypeIs[SupportsPP]:
def supports_pp( def supports_pp(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: ) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
supports_attributes = _supports_pp_attributes(model) supports_attributes = _supports_pp_attributes(model)
supports_inspect = _supports_pp_inspect(model) supports_inspect = _supports_pp_inspect(model)
...@@ -298,14 +298,14 @@ def supports_pp( ...@@ -298,14 +298,14 @@ def supports_pp(
return supports_attributes and supports_inspect return supports_attributes and supports_inspect
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool: def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _SupportsPPType) return isinstance(model, _SupportsPPType)
return isinstance(model, SupportsPP) return isinstance(model, SupportsPP)
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool: def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None) model_forward = getattr(model, "forward", None)
if not callable(model_forward): if not callable(model_forward):
return False return False
...@@ -336,13 +336,13 @@ def has_inner_state(model: object) -> TypeIs[HasInnerState]: ...@@ -336,13 +336,13 @@ def has_inner_state(model: object) -> TypeIs[HasInnerState]:
@overload @overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]: def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
... ...
def has_inner_state( def has_inner_state(
model: Union[Type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]: ) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _HasInnerStateType) return isinstance(model, _HasInnerStateType)
...@@ -373,13 +373,13 @@ def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: ...@@ -373,13 +373,13 @@ def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
@overload @overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]: def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
... ...
def is_attention_free( def is_attention_free(
model: Union[Type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]: ) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _IsAttentionFreeType) return isinstance(model, _IsAttentionFreeType)
...@@ -410,13 +410,13 @@ def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...@@ -410,13 +410,13 @@ def is_hybrid(model: object) -> TypeIs[IsHybrid]:
@overload @overload
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]: def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
... ...
def is_hybrid( def is_hybrid(
model: Union[Type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]: ) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _IsHybridType) return isinstance(model, _IsHybridType)
...@@ -439,13 +439,13 @@ def has_noops(model: object) -> TypeIs[HasNoOps]: ...@@ -439,13 +439,13 @@ def has_noops(model: object) -> TypeIs[HasNoOps]:
@overload @overload
def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]: def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
... ...
def has_noops( def has_noops(
model: Union[Type[object], object] model: Union[type[object], object]
) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]: ) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, _HasNoOpsType) return isinstance(model, _HasNoOpsType)
...@@ -461,7 +461,7 @@ class SupportsCrossEncoding(Protocol): ...@@ -461,7 +461,7 @@ class SupportsCrossEncoding(Protocol):
@overload @overload
def supports_cross_encoding( def supports_cross_encoding(
model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
... ...
...@@ -471,8 +471,8 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ...@@ -471,8 +471,8 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
def _supports_cross_encoding( def _supports_cross_encoding(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: ) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, SupportsCrossEncoding) return isinstance(model, SupportsCrossEncoding)
...@@ -481,15 +481,15 @@ def _supports_cross_encoding( ...@@ -481,15 +481,15 @@ def _supports_cross_encoding(
def supports_cross_encoding( def supports_cross_encoding(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: ) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return is_pooling_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
class SupportsQuant: class SupportsQuant:
"""The interface required for all models that support quantization.""" """The interface required for all models that support quantization."""
packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
quant_config: Optional[QuantizationConfig] = None quant_config: Optional[QuantizationConfig] = None
def __new__(cls, *args, **kwargs) -> Self: def __new__(cls, *args, **kwargs) -> Self:
...@@ -525,7 +525,7 @@ class SupportsTranscription(Protocol): ...@@ -525,7 +525,7 @@ class SupportsTranscription(Protocol):
@overload @overload
def supports_transcription( def supports_transcription(
model: Type[object]) -> TypeIs[Type[SupportsTranscription]]: model: type[object]) -> TypeIs[type[SupportsTranscription]]:
... ...
...@@ -535,8 +535,8 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: ...@@ -535,8 +535,8 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
def supports_transcription( def supports_transcription(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]: ) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, SupportsTranscription) return isinstance(model, SupportsTranscription)
...@@ -551,7 +551,7 @@ class SupportsV0Only(Protocol): ...@@ -551,7 +551,7 @@ class SupportsV0Only(Protocol):
@overload @overload
def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]: def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
... ...
...@@ -561,8 +561,8 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: ...@@ -561,8 +561,8 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
def supports_v0_only( def supports_v0_only(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]: ) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
if isinstance(model, type): if isinstance(model, type):
return isinstance(model, SupportsV0Only) return isinstance(model, SupportsV0Only)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import (TYPE_CHECKING, Optional, Protocol, Type, Union, overload, from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
runtime_checkable) runtime_checkable)
import torch import torch
...@@ -20,7 +20,7 @@ logger = init_logger(__name__) ...@@ -20,7 +20,7 @@ logger = init_logger(__name__)
# The type of hidden states # The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa # Currently, T = torch.Tensor for all models except for Medusa
# which has T = List[torch.Tensor] # which has T = list[torch.Tensor]
T = TypeVar("T", default=torch.Tensor) T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)
...@@ -48,12 +48,12 @@ class VllmModel(Protocol[T_co]): ...@@ -48,12 +48,12 @@ class VllmModel(Protocol[T_co]):
... ...
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
model_init = model.__init__ model_init = model.__init__
return supports_kw(model_init, "vllm_config") return supports_kw(model_init, "vllm_config")
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
model_forward = getattr(model, "forward", None) model_forward = getattr(model, "forward", None)
if not callable(model_forward): if not callable(model_forward):
return False return False
...@@ -75,7 +75,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: ...@@ -75,7 +75,7 @@ def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
@overload @overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]: def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
... ...
...@@ -85,8 +85,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...@@ -85,8 +85,8 @@ def is_vllm_model(model: object) -> TypeIs[VllmModel]:
def is_vllm_model( def is_vllm_model(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]: ) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
return _check_vllm_model_init(model) and _check_vllm_model_forward(model) return _check_vllm_model_init(model) and _check_vllm_model_forward(model)
...@@ -105,7 +105,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): ...@@ -105,7 +105,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
@overload @overload
def is_text_generation_model( def is_text_generation_model(
model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
... ...
...@@ -116,8 +116,8 @@ def is_text_generation_model( ...@@ -116,8 +116,8 @@ def is_text_generation_model(
def is_text_generation_model( def is_text_generation_model(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]], ) -> Union[TypeIs[type[VllmModelForTextGeneration]],
TypeIs[VllmModelForTextGeneration]]: TypeIs[VllmModelForTextGeneration]]:
if not is_vllm_model(model): if not is_vllm_model(model):
return False return False
...@@ -142,7 +142,7 @@ class VllmModelForPooling(VllmModel[T], Protocol[T]): ...@@ -142,7 +142,7 @@ class VllmModelForPooling(VllmModel[T], Protocol[T]):
@overload @overload
def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]: def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
... ...
...@@ -152,8 +152,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...@@ -152,8 +152,8 @@ def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
def is_pooling_model( def is_pooling_model(
model: Union[Type[object], object], model: Union[type[object], object],
) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: ) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
if not is_vllm_model(model): if not is_vllm_model(model):
return False return False
......
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
# Copyright (c) 2023 OpenGVLab # Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from collections.abc import Iterable
from functools import partial from functools import partial
from typing import Iterable, Optional, Set, Tuple from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -461,10 +462,10 @@ class InternVisionModel(nn.Module): ...@@ -461,10 +462,10 @@ class InternVisionModel(nn.Module):
return encoder_outputs return encoder_outputs
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable
from functools import partial from functools import partial
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Type, Union from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -81,7 +82,7 @@ class InternLM2Attention(nn.Module): ...@@ -81,7 +82,7 @@ class InternLM2Attention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -225,7 +226,7 @@ class InternLMDecoderLayer(nn.Module): ...@@ -225,7 +226,7 @@ class InternLMDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -252,7 +253,7 @@ class InternLM2Model(nn.Module): ...@@ -252,7 +253,7 @@ class InternLM2Model(nn.Module):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer): layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -316,7 +317,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -316,7 +317,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model): model_type: type[InternLM2Model] = InternLM2Model):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -361,15 +362,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -361,15 +362,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0), ("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1), ("gate_up_proj", "w3", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -407,7 +408,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): ...@@ -407,7 +408,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
*, *,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
model_type: Type[InternLM2Model] = InternLM2Model, model_type: type[InternLM2Model] = InternLM2Model,
): ):
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple, Union from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -66,7 +66,7 @@ class InternLM2VEDecoderLayer(nn.Module): ...@@ -66,7 +66,7 @@ class InternLM2VEDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
visual_token_mask: Optional[torch.Tensor] = None, visual_token_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
# -------------------------------------------------------- # --------------------------------------------------------
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union from typing import Literal, Optional, TypedDict, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -932,8 +932,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -932,8 +932,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
# unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B # unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
skip_prefixes = [ skip_prefixes = [
"action_embed", "temporal_embed", "track_embed", "action_embed", "temporal_embed", "track_embed",
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -333,10 +334,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -333,10 +334,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if "lm_head.weight" in name:
# GPT-2 ties the weights of the embedding layer and the final # GPT-2 ties the weights of the embedding layer and the final
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only Jamba model.""" """Inference-only Jamba model."""
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -442,7 +443,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -442,7 +443,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
conv_state_shape = ( conv_state_shape = (
...@@ -464,8 +465,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -464,8 +465,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -482,7 +483,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -482,7 +483,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
num_experts=self.config.num_experts) num_experts=self.config.num_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -583,7 +584,7 @@ class JambaForSequenceClassification(JambaForCausalLM): ...@@ -583,7 +584,7 @@ class JambaForSequenceClassification(JambaForCausalLM):
logits = self.score(hidden_states) logits = self.score(hidden_states)
return self._pooler(logits, pooling_metadata) return self._pooler(logits, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: The reward weights themselves have float32 accuracy data, we # TODO: The reward weights themselves have float32 accuracy data, we
# would like to load them in fp32 to get that extra precision. # would like to load them in fp32 to get that extra precision.
super().load_weights(weights) super().load_weights(weights)
......
...@@ -43,10 +43,9 @@ ...@@ -43,10 +43,9 @@
import copy import copy
import math import math
from collections.abc import Mapping from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import (Any, Iterable, List, Literal, Optional, Sequence, Tuple, from typing import Any, Literal, Optional, TypedDict, Union
TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -120,7 +119,7 @@ class KimiVLMultiModalProjector(nn.Module): ...@@ -120,7 +119,7 @@ class KimiVLMultiModalProjector(nn.Module):
class KimiVLImagePixelInputs(TypedDict): class KimiVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"] type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, List[torch.Tensor]] pixel_values: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape:`(num_patches, num_channels, patch_size, patch_size)` Shape:`(num_patches, num_channels, patch_size, patch_size)`
""" """
...@@ -447,7 +446,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -447,7 +446,7 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata, **kwargs) sampling_metadata, **kwargs)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
config = self.config.text_config config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head", "language_model.lm_head": "lm_head",
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -103,7 +104,7 @@ class LlamaAttention(nn.Module): ...@@ -103,7 +104,7 @@ class LlamaAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -285,7 +286,7 @@ class LlamaDecoderLayer(nn.Module): ...@@ -285,7 +286,7 @@ class LlamaDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -394,8 +395,8 @@ class LlamaModel(nn.Module): ...@@ -394,8 +395,8 @@ class LlamaModel(nn.Module):
return hidden_states, aux_hidden_states return hidden_states, aux_hidden_states
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -405,7 +406,7 @@ class LlamaModel(nn.Module): ...@@ -405,7 +406,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -582,8 +583,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -582,8 +583,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
...@@ -599,7 +600,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -599,7 +600,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
name: str, name: str,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]: ) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int): def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads attn_in = self.config.head_dim * n_heads
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from collections.abc import Iterable
from typing import Any, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -48,7 +49,7 @@ class Llama4MoE(nn.Module): ...@@ -48,7 +49,7 @@ class Llama4MoE(nn.Module):
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk: int, topk: int,
renormalize: bool, renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1) router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
# pseudo-standard is that the router scores are floats # pseudo-standard is that the router scores are floats
router_scores = torch.sigmoid(router_scores.float()) router_scores = torch.sigmoid(router_scores.float())
...@@ -115,7 +116,7 @@ class Llama4Attention(nn.Module): ...@@ -115,7 +116,7 @@ class Llama4Attention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -300,7 +301,7 @@ class Llama4DecoderLayer(nn.Module): ...@@ -300,7 +301,7 @@ class Llama4DecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -335,9 +336,9 @@ class Llama4Model(LlamaModel): ...@@ -335,9 +336,9 @@ class Llama4Model(LlamaModel):
self, self,
name: str, name: str,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
params_dict: Dict[str, nn.Parameter], params_dict: dict[str, nn.Parameter],
loaded_params: Set[str], loaded_params: set[str],
expert_params_mapping: List[Tuple[str, str, int, str]], expert_params_mapping: list[tuple[str, str, int, str]],
fused: bool = True, fused: bool = True,
) -> bool: ) -> bool:
expert_param_loaded = False expert_param_loaded = False
...@@ -390,8 +391,8 @@ class Llama4Model(LlamaModel): ...@@ -390,8 +391,8 @@ class Llama4Model(LlamaModel):
expert_param_loaded = True expert_param_loaded = True
return expert_param_loaded return expert_param_loaded
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -412,7 +413,7 @@ class Llama4Model(LlamaModel): ...@@ -412,7 +413,7 @@ class Llama4Model(LlamaModel):
ckpt_up_proj_name="gate_up_proj", ckpt_up_proj_name="gate_up_proj",
num_experts=1) num_experts=1)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "experts.gate_up_proj" in name or "experts.down_proj" in name: if "experts.gate_up_proj" in name or "experts.down_proj" in name:
fused_experts_params = True fused_experts_params = True
...@@ -489,8 +490,8 @@ class Llama4ForCausalLM(LlamaForCausalLM): ...@@ -489,8 +490,8 @@ class Llama4ForCausalLM(LlamaForCausalLM):
prefix=prefix, prefix=prefix,
layer_type=layer_type) layer_type=layer_type)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
...@@ -506,7 +507,7 @@ class Llama4ForCausalLM(LlamaForCausalLM): ...@@ -506,7 +507,7 @@ class Llama4ForCausalLM(LlamaForCausalLM):
self, self,
name: str, name: str,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
) -> Tuple[str, torch.Tensor]: ) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int): def permute(w: torch.Tensor, n_heads: int):
attn_in = self.config.head_dim * n_heads attn_in = self.config.head_dim * n_heads
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Set, Tuple from collections.abc import Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -92,8 +92,8 @@ class LlamaModel(nn.Module): ...@@ -92,8 +92,8 @@ class LlamaModel(nn.Module):
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
return hidden_states, hidden_states return hidden_states, hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -103,7 +103,7 @@ class LlamaModel(nn.Module): ...@@ -103,7 +103,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
...@@ -150,7 +150,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): ...@@ -150,7 +150,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states) return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=None, skip_prefixes=None,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -56,7 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer): ...@@ -56,7 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
embeds: torch.Tensor, embeds: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states residual = hidden_states
embeds = self.input_layernorm(embeds) embeds = self.input_layernorm(embeds)
...@@ -140,8 +141,8 @@ class LlamaModel(nn.Module): ...@@ -140,8 +141,8 @@ class LlamaModel(nn.Module):
hidden_states, hidden_prenorm = self.norm(hidden_states, residual) hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
return hidden_states, hidden_prenorm return hidden_states, hidden_prenorm
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -151,7 +152,7 @@ class LlamaModel(nn.Module): ...@@ -151,7 +152,7 @@ class LlamaModel(nn.Module):
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if 'midlayer.' in name: if 'midlayer.' in name:
name = name.replace('midlayer.', 'layers.0.') name = name.replace('midlayer.', 'layers.0.')
...@@ -228,7 +229,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ...@@ -228,7 +229,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
# combine multiple auxiliary hidden states returned by eagle3 # combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states) return self.model.fc(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=None, skip_prefixes=None,
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
TypeVar, Union, cast) Union, cast)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -751,8 +751,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -751,8 +751,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod from abc import abstractmethod
from typing import (Final, Iterable, List, Literal, Mapping, Optional, from collections.abc import Iterable, Mapping
Protocol, Set, Tuple, TypedDict, TypeVar, Union) from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -266,8 +267,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -266,8 +267,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return data return data
def _validate_pixel_values( def _validate_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -450,7 +451,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -450,7 +451,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input( def _process_image_input(
self, self,
image_input: LlavaNextImageInputs, image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return [image_input["data"]] return [image_input["data"]]
...@@ -577,7 +578,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -577,7 +578,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment