Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
......@@ -2,7 +2,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -35,7 +35,7 @@ from .vision import get_vision_encoder_info
class LlavaNextVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape: `(batch_size, num_frames, num_channels, height, width)`
......@@ -300,8 +300,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model.model.make_empty_intermediate_tensors)
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -326,7 +326,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
list[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
......@@ -460,8 +460,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# This model doesn't support images for now
......
......@@ -2,8 +2,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union)
from typing import Final, Literal, Optional, Protocol, TypedDict, Union
import torch
import torch.nn as nn
......@@ -471,8 +470,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -530,8 +529,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -557,7 +556,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
list[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
......@@ -706,7 +705,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_pixels(
self,
inputs: LlavaOnevisionImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, list[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
......@@ -735,7 +734,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self,
image_input: LlavaOnevisionImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
......@@ -948,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model."""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -30,7 +31,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
KVCache = tuple[torch.Tensor, torch.Tensor]
class MambaDecoderLayer(nn.Module):
......@@ -153,10 +154,10 @@ class MambaModel(nn.Module):
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
......@@ -247,7 +248,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape = (
self.config.intermediate_size // world_size,
......@@ -265,7 +266,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model."""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
......@@ -35,7 +36,7 @@ from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor]
KVCache = tuple[torch.Tensor, torch.Tensor]
class Mamba2DecoderLayer(nn.Module):
......@@ -241,7 +242,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None
......@@ -279,10 +280,10 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Tuple
import torch
......@@ -25,8 +24,8 @@ class MambaCacheParams:
class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
num_mamba_layers: int, conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int]):
num_mamba_layers: int, conv_state_shape: tuple[int, int],
temporal_state_shape: tuple[int, int]):
# Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs
......
# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
......@@ -96,13 +97,13 @@ class Medusa(nn.Module):
# checkpoint file has token_map tensor.
self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
return [block(hidden_states) for block in self.blocks]
def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits_lst: List[torch.Tensor] = []
self, hidden_states: list[torch.Tensor],
sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
logits_lst: list[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata)
......@@ -127,9 +128,9 @@ class Medusa(nn.Module):
def sample(
self,
logits: List[torch.Tensor],
logits: list[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now
......@@ -144,7 +145,7 @@ class Medusa(nn.Module):
token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices])
outputs: List[Optional[SamplerOutput]] = []
outputs: list[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
outputs.append(
SamplerOutput(
......@@ -160,7 +161,7 @@ class Medusa(nn.Module):
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
......@@ -169,10 +170,10 @@ class Medusa(nn.Module):
sampling_metadata=sampling_metadata,
)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
weights_map = {}
......
......@@ -24,7 +24,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiMo model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
......@@ -87,8 +88,8 @@ class MiMoModel(Qwen2Model):
hidden_states = hidden_states + residual
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
......@@ -97,7 +98,7 @@ class MiMoModel(Qwen2Model):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "mtp_layers" in name:
continue
......
......@@ -18,7 +18,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiMo-MTP model."""
from typing import Iterable, Optional, Set, Tuple
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
......@@ -193,8 +194,8 @@ class MiMoMTP(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
......@@ -204,7 +205,7 @@ class MiMoMTP(nn.Module):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
......
......@@ -23,7 +23,8 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Any, Optional, Union
import torch
from torch import nn
......@@ -190,7 +191,7 @@ class MiniCPMAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......@@ -329,7 +330,7 @@ class MiniCPMDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
......@@ -428,8 +429,8 @@ class MiniCPMModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -446,7 +447,7 @@ class MiniCPMModel(nn.Module):
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......@@ -582,8 +583,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
......
......@@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional
from typing import Any, Optional
import torch
from torch import nn
......@@ -58,7 +58,7 @@ class MiniCPM3Attention(nn.Module):
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
......
......@@ -23,8 +23,7 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
from typing import Any, Callable, Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -559,8 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
self.audio_encoder_layer = -1
return model
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
return loader.load_weights(weights)
......
......@@ -26,8 +26,7 @@ import math
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
from typing import Any, Callable, Literal, Optional, TypedDict, Union
import numpy as np
import torch
......@@ -118,7 +117,7 @@ class Resampler2_5(BaseResampler):
num_heads: int,
kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70),
max_size: tuple[int, int] = (70, 70),
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__(num_queries,
......@@ -133,7 +132,7 @@ class Resampler2_5(BaseResampler):
self._set_2d_pos_cache(self.max_size)
def _set_2d_pos_cache(self,
max_size: Tuple[int, int],
max_size: tuple[int, int],
device: torch.types.Device = "cpu") -> None:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
max_size,
......@@ -203,7 +202,7 @@ class Resampler2_5(BaseResampler):
return x
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
version_float = getattr(config, "version", None)
# The old configs do not include version number
......@@ -938,8 +937,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[torch.Tensor]:
return self.llm.compute_logits(hidden_states, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
......@@ -3,7 +3,8 @@
import copy
import math
import re
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.distributed
......@@ -127,7 +128,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
......@@ -178,7 +179,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
query_cast = query.to(self.cache_dtype)
......@@ -708,11 +709,11 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
kv_caches: Union[List[Dict], Optional[torch.Tensor]],
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
is_warmup: bool = False,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
......@@ -1072,10 +1073,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
def which_layer(name: str) -> int:
if "layers" in name:
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast
from typing import Literal, Optional, TypedDict, Union, cast
import torch
import torch.nn as nn
......@@ -357,7 +357,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......@@ -2,8 +2,8 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
TypeVar, Union)
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
Union)
import torch
import torch.nn as nn
......@@ -589,8 +589,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
......
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
......@@ -314,8 +315,8 @@ class MixtralModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -332,7 +333,7 @@ class MixtralModel(nn.Module):
num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
......@@ -479,7 +480,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
return loader.load_weights(weights)
......@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import numpy as np
import torch
......@@ -397,8 +398,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
......@@ -407,7 +408,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
......
......@@ -16,7 +16,7 @@
"""PyTorch Mllama model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import numpy as np
import torch
......@@ -224,7 +224,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
return mm_inputs
def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int:
def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == self.info.get_hf_config().image_token_index:
......@@ -370,8 +370,8 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
kernel_size: Union[int, tuple[int, int]],
stride: Union[int, tuple[int, int]],
bias: bool = False,
) -> None:
super().__init__()
......@@ -603,7 +603,7 @@ class MllamaVisionEncoder(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]:
) -> Union[BaseModelOutput]:
encoder_states = ()
for i, encoder_layer in enumerate(self.layers):
......@@ -878,7 +878,7 @@ class MllamaTextCrossAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
kv_range_for_decode: Optional[list[tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor:
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
......@@ -905,7 +905,7 @@ class MllamaTextCrossAttention(nn.Module):
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]],
kv_range_for_decode: list[tuple[int, int]],
) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
......@@ -1019,7 +1019,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]],
kv_range_for_decode: Optional[list[tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
......@@ -1089,8 +1089,8 @@ class MllamaTextModel(nn.Module):
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
kv_range_for_decode: Optional[list[tuple[int, int]]],
full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
torch.Tensor]],
skip_cross_attention: bool,
) -> torch.Tensor:
......@@ -1150,8 +1150,8 @@ class MllamaForCausalLM(nn.Module):
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
kv_range_for_decode: Optional[list[tuple[int, int]]],
full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
torch.Tensor]],
skip_cross_attention: bool,
) -> torch.Tensor:
......@@ -1221,7 +1221,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
return logits
def unpack_data(self,
image_data: Union[List[torch.Tensor], torch.Tensor],
image_data: Union[list[torch.Tensor], torch.Tensor],
padding_value=0) -> torch.Tensor:
if isinstance(image_data, torch.Tensor):
# torch.Tensor
......@@ -1230,7 +1230,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
assert isinstance(
image_data[0],
torch.Tensor), "Image data is not properly batched."
# List[torch.Tensor]
# list[torch.Tensor]
bsz = len(image_data)
max_length = max(t.size(0) for t in image_data)
trailing_dims = image_data[0].shape[1:]
......@@ -1248,24 +1248,24 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_image_input(self, **kwargs: object):
# tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be:
# - List[torch.Tensor]:
# - list[torch.Tensor]:
# with shape (num_image, num_tiles, 3, image_res, image_res)
# - torch.Tensor:
# with shape (bs, num_image, num_tiles, 3, image_res, image_res)
pixel_values: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
pixel_values: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"pixel_values", None)
image_embeds: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
image_embeds: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"image_embeds", None)
aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"aspect_ratio_ids", None)
aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
List[torch.Tensor],
aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
torch.Tensor]] = kwargs.pop(
"aspect_ratio_mask", None)
......@@ -1293,10 +1293,10 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def _get_and_validate_encoder_lens(
self,
encoder_seq_lens: List[int],
num_tiles: List[List[int]],
encoder_seq_lens: list[int],
num_tiles: list[list[int]],
num_tokens_per_tile: int,
) -> List[int]:
) -> list[int]:
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
......@@ -1318,7 +1318,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]):
actual_encoder_seq_lens: list[int]):
cross_attention_states_flat = torch.zeros(
sum(actual_encoder_seq_lens),
......@@ -1342,8 +1342,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
image_inputs: MllamaImagePixelInputs,
attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int],
) -> Tuple[torch.Tensor]:
actual_encoder_seq_lens: list[int],
) -> tuple[torch.Tensor]:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids']
......@@ -1367,10 +1367,10 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
num_tiles: List[List[int]],
num_tiles: list[list[int]],
num_tokens_per_tile: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
token_ids = input_ids.tolist()
start = 0
batch_token_ids = []
......@@ -1422,7 +1422,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
input_ids: torch.Tensor,
positions: torch.Tensor,
**kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]:
) -> Union[CausalLMOutputWithPast]:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0:
......@@ -1476,8 +1476,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
return outputs
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
......@@ -1487,7 +1487,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
updated_params: Set[str] = set()
updated_params: set[str] = set()
for name, loaded_weight in weights:
if 'patch_embedding.weight' in name:
name = name.replace('patch_embedding.weight',
......@@ -1538,7 +1538,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
tower_model="vision_model")
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
def skip_attention_mask(sparse_mask: list[list[int]]) -> bool:
for mask in sparse_mask:
# Skip text-only samples.
if len(mask) == 0:
......@@ -1556,10 +1556,10 @@ def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
def convert_sparse_cross_attention_mask_to_dense(
sparse_mask: List[List[List[int]]],
num_tiles: List[List[int]],
lengths: List[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
sparse_mask: list[list[list[int]]],
num_tiles: list[list[int]],
lengths: list[int],
) -> tuple[np.ndarray, list[tuple[int, int]]]:
total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
......
......@@ -18,7 +18,7 @@
import math
from collections.abc import Iterable, Mapping
from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -582,7 +582,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> List[PromptUpdate]:
) -> list[PromptUpdate]:
assert (
mm_items.get_count("image", strict=False) == 0
or "aspect_ratios" in out_mm_kwargs
......@@ -778,26 +778,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
def separate_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
prefix: str,
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[
str, torch.Tensor]]]:
weights1, weights2 = tee(weights, 2)
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
for name, data in weights1:
if name.startswith(prefix):
yield (name, data)
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]:
for name, data in weights2:
if not name.startswith(prefix):
yield (name, data)
return get_prefix_weights(), get_other_weights()
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......@@ -806,7 +806,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
updated_params: Set[str] = set()
updated_params: set[str] = set()
# language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine.
......
# SPDX-License-Identifier: Apache-2.0
import math
from typing import Iterable, List, Set, Tuple
from collections.abc import Iterable
import torch
import torch.nn as nn
......@@ -148,7 +148,7 @@ class MLPSpeculator(nn.Module):
previous_hidden_states: torch.Tensor,
num_predict_tokens: int,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but "
......@@ -190,10 +190,10 @@ class MLPSpeculator(nn.Module):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
name = name.replace("speculator.", "")
param = params_dict.get(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment