Unverified commit 26d04193, authored by Harry Mellor, committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
......@@ -77,7 +77,6 @@ exclude = [
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
......
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -458,8 +459,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -467,8 +468,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
("qkv_proj", "v_proj", "v"),
]
mlp_params_mapping: List[Tuple[str, str, int]] = []
expert_params_mapping: List[Tuple[str, str, int]] = []
mlp_params_mapping: list[tuple[str, str, int]] = []
expert_params_mapping: list[tuple[str, str, int]] = []
num_layers = self.config.num_hidden_layers
for layer in range(num_layers):
......@@ -497,7 +498,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
("ws", f"experts.{expert_id}.w3.weight", expert_id))
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
logger.info(
"It will take ~10 minutes loading from the 16-bit weights. "
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
from typing import Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -66,8 +66,8 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
# Identity layer
self.post_layernorm = nn.Identity()
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -75,7 +75,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# NOTE: post_layernorm is not used in Aria
......@@ -326,8 +326,8 @@ class AriaTextModel(LlamaModel, SupportsQuant):
# Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping`
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -339,7 +339,7 @@ class AriaTextModel(LlamaModel, SupportsQuant):
("experts.w2_weight", "experts.fc2.weight", 'w2'),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -528,7 +528,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vocab_size, logit_scale)
def _validate_image_sizes(
self, images: List[torch.Tensor]) -> List[torch.Tensor]:
self, images: list[torch.Tensor]) -> list[torch.Tensor]:
if not all(img.shape == images[0].shape for img in images):
raise ValueError("All images must be the same size")
return images
......@@ -578,7 +578,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def _process_image_input(
self, image_input: AriaImagePixelInputs
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input['pixel_values']
......@@ -651,6 +651,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
TypedDict, Union, cast)
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union, cast
import torch
from torch import nn
......@@ -315,8 +315,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def dtype(self):
return next(self.parameters()).dtype
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
......@@ -20,7 +20,8 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -230,7 +231,7 @@ class BaiChuanDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
......@@ -320,15 +321,15 @@ class BaiChuanModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -421,8 +422,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
# SPDX-License-Identifier: Apache-2.0
"""Inference-only Bamba model."""
# Added by the IBM Team, 2024
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -355,8 +356,8 @@ class BambaModel(nn.Module):
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -367,7 +368,7 @@ class BambaModel(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -495,7 +496,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
......@@ -535,7 +536,7 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -19,7 +19,8 @@
# limitations under the License.
"""PyTorch BART model."""
import math
from typing import Iterable, Optional, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -859,14 +860,14 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
def _rename_stacked_param(
self,
name: str,
) -> Tuple[str, Optional[str]]:
) -> tuple[str, Optional[str]]:
for key, mapping in self.stacked_params_mapping.items():
if key in name:
name = name.replace(key, mapping["param_name"])
return name, mapping["shard_id"]
return name, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_params_dict = dict(self.model.named_parameters())
top_params_dict = dict(self.named_parameters())
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -349,8 +350,8 @@ class BertModel(nn.Module, SupportsQuant):
token_type_ids=token_type_ids)
return self.encoder(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
......@@ -359,7 +360,7 @@ class BertModel(nn.Module, SupportsQuant):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.pooler is None and "pooler" in name:
continue
......@@ -424,7 +425,7 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
......@@ -472,7 +473,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
self._pooler = CrossEncodingPooler(config, self.classifier,
self.bert.pooler)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = []
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -208,7 +209,7 @@ class NomicRouter(nn.Module):
def forward(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
dim=-1, dtype=torch.float32)
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
......@@ -428,8 +429,8 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
token_type_ids=token_type_ids)
return self.encoder(positions, hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.hf_to_vllm_mapper.apply(weights)
if self.config.hidden_act in ["silu", "geglu"]:
......@@ -442,7 +443,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
stacked_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "pooler" in name:
continue
......@@ -567,7 +568,7 @@ class GteNewModel(BertWithRope):
}
return config
def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def split_up_gate_proj(self, weights: Iterable[tuple[str, torch.Tensor]]):
n = "mlp.up_gate_proj"
for name, weight in weights:
if n in name:
......@@ -578,14 +579,14 @@ class GteNewModel(BertWithRope):
yield name, weight
def ignore_unnecessary_layers(self,
weights: Iterable[Tuple[str, torch.Tensor]]):
weights: Iterable[tuple[str, torch.Tensor]]):
for name, weight in weights:
if name.startswith("classifier"):
continue
yield name, weight
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.ignore_unnecessary_layers(weights)
weights = self.split_up_gate_proj(weights)
return super().load_weights(weights)
......@@ -664,7 +665,7 @@ class JinaRobertaModel(BertWithRope):
token_type_ids=token_type_ids)
@torch.inference_mode()
def jina_merge_lora_weights(self, weights: Iterable[Tuple[str,
def jina_merge_lora_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]):
# use for jina-embeddings-v3
# Merge Lora weights into a single weight tensor.
......@@ -707,7 +708,7 @@ class JinaRobertaModel(BertWithRope):
return [(name, weight) for name, weight in weights.items()]
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
weights = self.jina_merge_lora_weights(weights)
return super().load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -296,8 +297,8 @@ class BlipVisionModel(nn.Module, SupportsQuant):
return self.post_layernorm(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -305,7 +306,7 @@ class BlipVisionModel(nn.Module, SupportsQuant):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)
for name, loaded_weight in weights:
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -186,7 +186,7 @@ class Blip2QFormerAttention(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor]:
) -> tuple[torch.Tensor]:
self_output = self.attention(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -712,7 +712,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -18,7 +18,8 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -322,10 +323,10 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
......
......@@ -2,7 +2,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -229,7 +229,7 @@ class ChameleonAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
......@@ -292,7 +292,7 @@ class ChameleonAttention(nn.Module):
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# reshape for layernorm
q = q.reshape(-1, self.num_heads, self.head_dim)
k = k.reshape(-1, self.num_kv_heads, self.head_dim)
......@@ -367,7 +367,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None:
residual = hidden_states
......@@ -438,7 +438,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.self_attn(
......@@ -773,7 +773,7 @@ class ChameleonVQVAE(nn.Module):
def encode(
self, pixel_values: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
......@@ -786,7 +786,7 @@ class ChameleonImageVocabularyMapping:
A class for mapping discrete image tokens from VQGAN to BPE tokens.
"""
def __init__(self, vocab_map: Dict[str, int]):
def __init__(self, vocab_map: dict[str, int]):
self.vocab_map = vocab_map
self.image_token_id = vocab_map.get("<image>")
......@@ -1052,8 +1052,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -1063,7 +1063,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -3,7 +3,8 @@
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
import json
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -358,15 +359,15 @@ class ChatGLMModel(nn.Module, SupportsQuant):
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
......@@ -440,7 +441,7 @@ class ChatGLMBaseModel(nn.Module):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
......
# SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -368,8 +369,8 @@ class CLIPVisionModel(nn.Module, SupportsQuant):
# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -377,7 +378,7 @@ class CLIPVisionModel(nn.Module, SupportsQuant):
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
......
......@@ -21,7 +21,8 @@
# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -259,7 +260,7 @@ class CohereDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states, residual = self.input_layernorm(hidden_states, residual)
......@@ -404,8 +405,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -415,7 +416,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Skip loading rotary embeddings since vLLM has its own
......
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Tuple
from typing import Any
import torch
......@@ -16,7 +16,7 @@ class ConstantSizeCache(ABC):
def __init__(self, max_batch_size: int):
# Maps between the request id and a dict that maps between the seq_id
# and its index inside the cache
self.cache_indices_mapping: Dict[str, Dict[int, int]] = {}
self.cache_indices_mapping: dict[str, dict[int, int]] = {}
self.free_cache_indices = list(range(max_batch_size))
@property
......@@ -30,7 +30,7 @@ class ConstantSizeCache(ABC):
"""Copy cache data from one index to another"""
pass
def current_run_tensors(self, **kwargs) -> Tuple:
def current_run_tensors(self, **kwargs) -> tuple:
"""
Return the tensors for the current run's conv and ssm state.
"""
......@@ -117,8 +117,8 @@ class ConstantSizeCache(ABC):
return self.cache_indices_mapping[cur_rid][seq_id]
def _prepare_current_run_cache(
self, request_ids_to_seq_ids: Dict[str, list[int]],
finished_requests_ids: List[str]) -> List[int]:
self, request_ids_to_seq_ids: dict[str, list[int]],
finished_requests_ids: list[str]) -> list[int]:
return [
self._assign_seq_id_to_cache_index(req_id, seq_id,
finished_requests_ids)
......@@ -127,7 +127,7 @@ class ConstantSizeCache(ABC):
]
def _release_finished_requests(self,
finished_seq_groups_req_ids: List[str]):
finished_seq_groups_req_ids: list[str]):
for req_id in finished_seq_groups_req_ids:
if req_id in self.cache_indices_mapping:
for seq_id in self.cache_indices_mapping[req_id]:
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -414,14 +415,14 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
expert_params_mapping = [(
"w13" if weight_name in ["w1", "v1"] else "w2",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -184,7 +185,7 @@ class DeepseekAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -385,8 +386,8 @@ class DeepseekModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -397,7 +398,7 @@ class DeepseekModel(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -478,7 +479,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
......@@ -176,8 +177,8 @@ class DeepSeekMTP(nn.Module):
return self.model.compute_logits(hidden_states, sampling_metadata,
spec_step_idx)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
......@@ -190,7 +191,7 @@ class DeepSeekMTP(nn.Module):
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment