Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
...@@ -77,7 +77,6 @@ exclude = [ ...@@ -77,7 +77,6 @@ exclude = [
"vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only Snowflake Arctic model.""" """Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -458,8 +459,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -458,8 +459,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -467,8 +468,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -467,8 +468,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
mlp_params_mapping: List[Tuple[str, str, int]] = [] mlp_params_mapping: list[tuple[str, str, int]] = []
expert_params_mapping: List[Tuple[str, str, int]] = [] expert_params_mapping: list[tuple[str, str, int]] = []
num_layers = self.config.num_hidden_layers num_layers = self.config.num_hidden_layers
for layer in range(num_layers): for layer in range(num_layers):
...@@ -497,7 +498,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -497,7 +498,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
("ws", f"experts.{expert_id}.w3.weight", expert_id)) ("ws", f"experts.{expert_id}.w3.weight", expert_id))
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
logger.info( logger.info(
"It will take ~10 minutes loading from the 16-bit weights. " "It will take ~10 minutes loading from the 16-bit weights. "
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union from typing import Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -66,8 +66,8 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): ...@@ -66,8 +66,8 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
# Identity layer # Identity layer
self.post_layernorm = nn.Identity() self.post_layernorm = nn.Identity()
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -75,7 +75,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): ...@@ -75,7 +75,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
# NOTE: post_layernorm is not used in Aria # NOTE: post_layernorm is not used in Aria
...@@ -326,8 +326,8 @@ class AriaTextModel(LlamaModel, SupportsQuant): ...@@ -326,8 +326,8 @@ class AriaTextModel(LlamaModel, SupportsQuant):
# Adapted from LlamaModel.load_weights with the modification of adding # Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping` # the expert weights mapping to `stacked_params_mapping`
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -339,7 +339,7 @@ class AriaTextModel(LlamaModel, SupportsQuant): ...@@ -339,7 +339,7 @@ class AriaTextModel(LlamaModel, SupportsQuant):
("experts.w2_weight", "experts.fc2.weight", 'w2'), ("experts.w2_weight", "experts.fc2.weight", 'w2'),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -528,7 +528,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -528,7 +528,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vocab_size, logit_scale) self.vocab_size, logit_scale)
def _validate_image_sizes( def _validate_image_sizes(
self, images: List[torch.Tensor]) -> List[torch.Tensor]: self, images: list[torch.Tensor]) -> list[torch.Tensor]:
if not all(img.shape == images[0].shape for img in images): if not all(img.shape == images[0].shape for img in images):
raise ValueError("All images must be the same size") raise ValueError("All images must be the same size")
return images return images
...@@ -578,7 +578,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -578,7 +578,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def _process_image_input( def _process_image_input(
self, image_input: AriaImagePixelInputs self, image_input: AriaImagePixelInputs
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert self.vision_tower is not None assert self.vision_tower is not None
pixel_values = image_input['pixel_values'] pixel_values = image_input['pixel_values']
...@@ -651,6 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -651,6 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# SPDX-License-Identifier: Apache-2.0 Adapted from # SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple, from collections.abc import Iterable, Mapping, Sequence
TypedDict, Union, cast) from typing import Literal, Optional, TypedDict, Union, cast
import torch import torch
from torch import nn from torch import nn
...@@ -315,8 +315,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -315,8 +315,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def dtype(self): def dtype(self):
return next(self.parameters()).dtype return next(self.parameters()).dtype
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -20,7 +20,8 @@ ...@@ -20,7 +20,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.""" """Inference-only BaiChuan model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -230,7 +231,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -230,7 +231,7 @@ class BaiChuanDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -320,15 +321,15 @@ class BaiChuanModel(nn.Module): ...@@ -320,15 +321,15 @@ class BaiChuanModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -421,8 +422,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -421,8 +422,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Inference-only Bamba model.""" """Inference-only Bamba model."""
# Added by the IBM Team, 2024 # Added by the IBM Team, 2024
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -355,8 +356,8 @@ class BambaModel(nn.Module): ...@@ -355,8 +356,8 @@ class BambaModel(nn.Module):
hidden_states, _ = self.final_layernorm(hidden_states, residual) hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -367,7 +368,7 @@ class BambaModel(nn.Module): ...@@ -367,7 +368,7 @@ class BambaModel(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -495,7 +496,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -495,7 +496,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size hidden_size = self.config.hidden_size
...@@ -535,7 +536,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, ...@@ -535,7 +536,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch BART model.""" """PyTorch BART model."""
import math import math
from typing import Iterable, Optional, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -859,14 +860,14 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -859,14 +860,14 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
def _rename_stacked_param( def _rename_stacked_param(
self, self,
name: str, name: str,
) -> Tuple[str, Optional[str]]: ) -> tuple[str, Optional[str]]:
for key, mapping in self.stacked_params_mapping.items(): for key, mapping in self.stacked_params_mapping.items():
if key in name: if key in name:
name = name.replace(key, mapping["param_name"]) name = name.replace(key, mapping["param_name"])
return name, mapping["shard_id"] return name, mapping["shard_id"]
return name, None return name, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_params_dict = dict(self.model.named_parameters()) model_params_dict = dict(self.model.named_parameters())
top_params_dict = dict(self.named_parameters()) top_params_dict = dict(self.named_parameters())
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -349,8 +350,8 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -349,8 +350,8 @@ class BertModel(nn.Module, SupportsQuant):
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(hidden_states) return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"), ("qkv_proj", "query", "q"),
...@@ -359,7 +360,7 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -359,7 +360,7 @@ class BertModel(nn.Module, SupportsQuant):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if self.pooler is None and "pooler" in name: if self.pooler is None and "pooler" in name:
continue continue
...@@ -424,7 +425,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -424,7 +425,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
) -> Optional[PoolerOutput]: ) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata) return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights weights = ((name, data) for name, data in weights
if not name.startswith("lm_head.")) if not name.startswith("lm_head."))
...@@ -472,7 +473,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, ...@@ -472,7 +473,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
self._pooler = CrossEncodingPooler(config, self.classifier, self._pooler = CrossEncodingPooler(config, self.classifier,
self.bert.pooler) self.bert.pooler)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = [] self_weights = []
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -208,7 +209,7 @@ class NomicRouter(nn.Module): ...@@ -208,7 +209,7 @@ class NomicRouter(nn.Module):
def forward( def forward(
self, x: torch.Tensor self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax( weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
dim=-1, dtype=torch.float32) dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1) top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
...@@ -428,8 +429,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -428,8 +429,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
return self.encoder(positions, hidden_states) return self.encoder(positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights) weights = self.hf_to_vllm_mapper.apply(weights)
if self.config.hidden_act in ["silu", "geglu"]: if self.config.hidden_act in ["silu", "geglu"]:
...@@ -442,7 +443,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant): ...@@ -442,7 +443,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
stacked_params_mapping = [] stacked_params_mapping = []
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "pooler" in name: if "pooler" in name:
continue continue
...@@ -567,7 +568,7 @@ class GteNewModel(BertWithRope): ...@@ -567,7 +568,7 @@ class GteNewModel(BertWithRope):
} }
return config return config
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]): def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj" n = "mlp.up_gate_proj"
for name, weight in weights: for name, weight in weights:
if n in name: if n in name:
...@@ -578,14 +579,14 @@ class GteNewModel(BertWithRope): ...@@ -578,14 +579,14 @@ class GteNewModel(BertWithRope):
yield name, weight yield name, weight
def ignore_unnecessary_layers(self, def ignore_unnecessary_layers(self,
weights: Iterable[Tuple[str, torch.Tensor]]): weights: Iterable[tuple[str, torch.Tensor]]):
for name, weight in weights: for name, weight in weights:
if name.startswith("classifier"): if name.startswith("classifier"):
continue continue
yield name, weight yield name, weight
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
weights = self.ignore_unnecessary_layers(weights) weights = self.ignore_unnecessary_layers(weights)
weights = self.split_up_gate_proj(weights) weights = self.split_up_gate_proj(weights)
return super().load_weights(weights) return super().load_weights(weights)
...@@ -664,7 +665,7 @@ class JinaRobertaModel(BertWithRope): ...@@ -664,7 +665,7 @@ class JinaRobertaModel(BertWithRope):
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
@torch.inference_mode() @torch.inference_mode()
def jina_merge_lora_weights(self, weights: Iterable[Tuple[str, def jina_merge_lora_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]): torch.Tensor]]):
# use for jina-embeddings-v3 # use for jina-embeddings-v3
# Merge Lora weights into a single weight tensor. # Merge Lora weights into a single weight tensor.
...@@ -707,7 +708,7 @@ class JinaRobertaModel(BertWithRope): ...@@ -707,7 +708,7 @@ class JinaRobertaModel(BertWithRope):
return [(name, weight) for name, weight in weights.items()] return [(name, weight) for name, weight in weights.items()]
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
weights = self.jina_merge_lora_weights(weights) weights = self.jina_merge_lora_weights(weights)
return super().load_weights(weights) return super().load_weights(weights)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of BlipVisionModel intended to be only used """Minimal implementation of BlipVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -296,8 +297,8 @@ class BlipVisionModel(nn.Module, SupportsQuant): ...@@ -296,8 +297,8 @@ class BlipVisionModel(nn.Module, SupportsQuant):
return self.post_layernorm(hidden_states) return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -305,7 +306,7 @@ class BlipVisionModel(nn.Module, SupportsQuant): ...@@ -305,7 +306,7 @@ class BlipVisionModel(nn.Module, SupportsQuant):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
layer_count = len(self.encoder.layers) layer_count = len(self.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -186,7 +186,7 @@ class Blip2QFormerAttention(nn.Module): ...@@ -186,7 +186,7 @@ class Blip2QFormerAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_output = self.attention( self_output = self.attention(
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
...@@ -712,7 +712,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -712,7 +712,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.""" """Inference-only BLOOM model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -322,10 +323,10 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): ...@@ -322,10 +323,10 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if name == "lm_head.weight": if name == "lm_head.weight":
continue continue
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Any, Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -229,7 +229,7 @@ class ChameleonAttention(nn.Module): ...@@ -229,7 +229,7 @@ class ChameleonAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096, max_position_embeddings: int = 4096,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
...@@ -292,7 +292,7 @@ class ChameleonAttention(nn.Module): ...@@ -292,7 +292,7 @@ class ChameleonAttention(nn.Module):
prefix=f"{prefix}.attn") prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor, def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# reshape for layernorm # reshape for layernorm
q = q.reshape(-1, self.num_heads, self.head_dim) q = q.reshape(-1, self.num_heads, self.head_dim)
k = k.reshape(-1, self.num_kv_heads, self.head_dim) k = k.reshape(-1, self.num_kv_heads, self.head_dim)
...@@ -367,7 +367,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -367,7 +367,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
...@@ -438,7 +438,7 @@ class ChameleonSwinDecoderLayer(nn.Module): ...@@ -438,7 +438,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn( hidden_states = self.self_attn(
...@@ -773,7 +773,7 @@ class ChameleonVQVAE(nn.Module): ...@@ -773,7 +773,7 @@ class ChameleonVQVAE(nn.Module):
def encode( def encode(
self, pixel_values: torch.Tensor self, pixel_values: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self.encoder(pixel_values) hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states) hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states) quant, emb_loss, indices = self.quantize(hidden_states)
...@@ -786,7 +786,7 @@ class ChameleonImageVocabularyMapping: ...@@ -786,7 +786,7 @@ class ChameleonImageVocabularyMapping:
A class for mapping discrete image tokens from VQGAN to BPE tokens. A class for mapping discrete image tokens from VQGAN to BPE tokens.
""" """
def __init__(self, vocab_map: Dict[str, int]): def __init__(self, vocab_map: dict[str, int]):
self.vocab_map = vocab_map self.vocab_map = vocab_map
self.image_token_id = vocab_map.get("<image>") self.image_token_id = vocab_map.get("<image>")
...@@ -1052,8 +1052,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1052,8 +1052,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -1063,7 +1063,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1063,7 +1063,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# https://github.com/THUDM/ChatGLM2-6B # https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights.""" """Inference-only ChatGLM model compatible with THUDM weights."""
import json import json
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -358,15 +359,15 @@ class ChatGLMModel(nn.Module, SupportsQuant): ...@@ -358,15 +359,15 @@ class ChatGLMModel(nn.Module, SupportsQuant):
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("linear_proj.merged_proj", "linear_proj.gate_proj", 0), ("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1), ("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
...@@ -440,7 +441,7 @@ class ChatGLMBaseModel(nn.Module): ...@@ -440,7 +441,7 @@ class ChatGLMBaseModel(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used """Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -368,8 +369,8 @@ class CLIPVisionModel(nn.Module, SupportsQuant): ...@@ -368,8 +369,8 @@ class CLIPVisionModel(nn.Module, SupportsQuant):
# (TODO) Add prefix argument for filtering out weights to be loaded # (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -377,7 +378,7 @@ class CLIPVisionModel(nn.Module, SupportsQuant): ...@@ -377,7 +378,7 @@ class CLIPVisionModel(nn.Module, SupportsQuant):
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
layer_count = len(self.vision_model.encoder.layers) layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
# This file is based on the LLama model definition file in transformers # This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model.""" """PyTorch Cohere model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -259,7 +260,7 @@ class CohereDecoderLayer(nn.Module): ...@@ -259,7 +260,7 @@ class CohereDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
...@@ -404,8 +405,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): ...@@ -404,8 +405,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -415,7 +416,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): ...@@ -415,7 +416,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
# Skip loading rotary embeddings since vLLM has its own # Skip loading rotary embeddings since vLLM has its own
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple from typing import Any
import torch import torch
...@@ -16,7 +16,7 @@ class ConstantSizeCache(ABC): ...@@ -16,7 +16,7 @@ class ConstantSizeCache(ABC):
def __init__(self, max_batch_size: int): def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id # Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache # and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {} self.cache_indices_mapping: dict[str, dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size)) self.free_cache_indices = list(range(max_batch_size))
@property @property
...@@ -30,7 +30,7 @@ class ConstantSizeCache(ABC): ...@@ -30,7 +30,7 @@ class ConstantSizeCache(ABC):
"""Copy cache data from one index to another""" """Copy cache data from one index to another"""
pass pass
def current_run_tensors(self, **kwargs) -> Tuple: def current_run_tensors(self, **kwargs) -> tuple:
""" """
Return the tensors for the current run's conv and ssm state. Return the tensors for the current run's conv and ssm state.
""" """
...@@ -117,8 +117,8 @@ class ConstantSizeCache(ABC): ...@@ -117,8 +117,8 @@ class ConstantSizeCache(ABC):
return self.cache_indices_mapping[cur_rid][seq_id] return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache( def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]], self, request_ids_to_seq_ids: dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]: finished_requests_ids: list[str]) -> list[int]:
return [ return [
self._assign_seq_id_to_cache_index(req_id, seq_id, self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids) finished_requests_ids)
...@@ -127,7 +127,7 @@ class ConstantSizeCache(ABC): ...@@ -127,7 +127,7 @@ class ConstantSizeCache(ABC):
] ]
def _release_finished_requests(self, def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]): finished_seq_groups_req_ids: list[str]):
for req_id in finished_seq_groups_req_ids: for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping: if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]: for seq_id in self.cache_indices_mapping[req_id]:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -414,14 +415,14 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -414,14 +415,14 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
expert_params_mapping = [( expert_params_mapping = [(
"w13" if weight_name in ["w1", "v1"] else "w2", "w13" if weight_name in ["w1", "v1"] else "w2",
f"mlp.{weight_name}", f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]] ) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Deepseek model.""" """Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -184,7 +185,7 @@ class DeepseekAttention(nn.Module): ...@@ -184,7 +185,7 @@ class DeepseekAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -385,8 +386,8 @@ class DeepseekModel(nn.Module): ...@@ -385,8 +386,8 @@ class DeepseekModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -397,7 +398,7 @@ class DeepseekModel(nn.Module): ...@@ -397,7 +398,7 @@ class DeepseekModel(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -478,7 +479,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -478,7 +479,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -176,8 +177,8 @@ class DeepSeekMTP(nn.Module): ...@@ -176,8 +177,8 @@ class DeepSeekMTP(nn.Module):
return self.model.compute_logits(hidden_states, sampling_metadata, return self.model.compute_logits(hidden_states, sampling_metadata,
spec_step_idx) spec_step_idx)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
...@@ -190,7 +191,7 @@ class DeepSeekMTP(nn.Module): ...@@ -190,7 +191,7 @@ class DeepSeekMTP(nn.Module):
num_experts=self.config.n_routed_experts) num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment