Unverified Commit 26d04193 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `models` (#18132)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 83f74c69
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -35,7 +35,7 @@ from .vision import get_vision_encoder_info ...@@ -35,7 +35,7 @@ from .vision import get_vision_encoder_info
class LlavaNextVideoPixelInputs(TypedDict): class LlavaNextVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
data: Union[torch.Tensor, List[torch.Tensor]] data: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size, num_frames, num_channels, height, width)` Shape: `(batch_size, num_frames, num_channels, height, width)`
...@@ -300,8 +300,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -300,8 +300,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
def _validate_video_pixel_values( def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -326,7 +326,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -326,7 +326,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
A legal video input should have the following dimensions: A legal video input should have the following dimensions:
{ {
"pixel_values_videos" : "pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)] list[b, Tensor(nb_frames, nb_channels, height, width)]
} }
""" """
pixel_values_videos = kwargs.pop("pixel_values_videos", None) pixel_values_videos = kwargs.pop("pixel_values_videos", None)
...@@ -460,8 +460,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -460,8 +460,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
# This model doesn't support images for now # This model doesn't support images for now
......
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple, from typing import Final, Literal, Optional, Protocol, TypedDict, Union
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -471,8 +470,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -471,8 +470,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return data return data
def _validate_image_pixel_values( def _validate_image_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -530,8 +529,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -530,8 +529,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values( def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -557,7 +556,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -557,7 +556,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
A legal video input should have the following dimensions: A legal video input should have the following dimensions:
{ {
"pixel_values_videos" : "pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)] list[b, Tensor(nb_frames, nb_channels, height, width)]
} }
""" """
pixel_values_videos = kwargs.pop("pixel_values_videos", None) pixel_values_videos = kwargs.pop("pixel_values_videos", None)
...@@ -706,7 +705,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -706,7 +705,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_pixels( def _process_image_pixels(
self, self,
inputs: LlavaOnevisionImagePixelInputs, inputs: LlavaOnevisionImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
assert self.vision_tower is not None assert self.vision_tower is not None
pixel_values = inputs["pixel_values"] pixel_values = inputs["pixel_values"]
...@@ -735,7 +734,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -735,7 +734,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input( def _process_image_input(
self, self,
image_input: LlavaOnevisionImageInputs, image_input: LlavaOnevisionImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]: ) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds": if image_input["type"] == "image_embeds":
return [image_input["data"]] return [image_input["data"]]
...@@ -948,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -948,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA model.""" """PyTorch MAMBA model."""
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -30,7 +31,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter, ...@@ -30,7 +31,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = tuple[torch.Tensor, torch.Tensor]
class MambaDecoderLayer(nn.Module): class MambaDecoderLayer(nn.Module):
...@@ -153,10 +154,10 @@ class MambaModel(nn.Module): ...@@ -153,10 +154,10 @@ class MambaModel(nn.Module):
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "A_log" in name: if "A_log" in name:
name = name.replace("A_log", "A") name = name.replace("A_log", "A")
...@@ -247,7 +248,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, ...@@ -247,7 +248,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
conv_state_shape = ( conv_state_shape = (
self.config.intermediate_size // world_size, self.config.intermediate_size // world_size,
...@@ -265,7 +266,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP, ...@@ -265,7 +266,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""PyTorch MAMBA2 model.""" """PyTorch MAMBA2 model."""
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
from torch import nn from torch import nn
...@@ -35,7 +36,7 @@ from .utils import (is_pp_missing_parameter, ...@@ -35,7 +36,7 @@ from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = tuple[torch.Tensor, torch.Tensor]
class Mamba2DecoderLayer(nn.Module): class Mamba2DecoderLayer(nn.Module):
...@@ -241,7 +242,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, ...@@ -241,7 +242,7 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape( def _get_mamba_cache_shape(
self) -> Tuple[Tuple[int, int], Tuple[int, int]]: self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None conv_state_shape, temporal_state_shape = None, None
...@@ -279,10 +280,10 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree, ...@@ -279,10 +280,10 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "A_log" in name: if "A_log" in name:
name = name.replace("A_log", "A") name = name.replace("A_log", "A")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple
import torch import torch
...@@ -25,8 +24,8 @@ class MambaCacheParams: ...@@ -25,8 +24,8 @@ class MambaCacheParams:
class MambaCacheManager(ConstantSizeCache): class MambaCacheManager(ConstantSizeCache):
def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype,
num_mamba_layers: int, conv_state_shape: Tuple[int, int], num_mamba_layers: int, conv_state_shape: tuple[int, int],
temporal_state_shape: Tuple[int, int]): temporal_state_shape: tuple[int, int]):
# Determine max batch size to set size of MambaCache # Determine max batch size to set size of MambaCache
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_batch_size = vllm_config.scheduler_config.max_num_seqs
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Iterable, List, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -96,13 +97,13 @@ class Medusa(nn.Module): ...@@ -96,13 +97,13 @@ class Medusa(nn.Module):
# checkpoint file has token_map tensor. # checkpoint file has token_map tensor.
self.token_map = None self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
return [block(hidden_states) for block in self.blocks] return [block(hidden_states) for block in self.blocks]
def compute_logits( def compute_logits(
self, hidden_states: List[torch.Tensor], self, hidden_states: list[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: sampling_metadata: SamplingMetadata) -> list[torch.Tensor]:
logits_lst: List[torch.Tensor] = [] logits_lst: list[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads): for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata) _logits = self.logits_processor(lm_head, hs, sampling_metadata)
...@@ -127,9 +128,9 @@ class Medusa(nn.Module): ...@@ -127,9 +128,9 @@ class Medusa(nn.Module):
def sample( def sample(
self, self,
logits: List[torch.Tensor], logits: list[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]: ) -> list[SamplerOutput]:
logits = torch.stack(logits, dim=0).float() logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1) logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now token_ids = logits.argmax(-1) # support only top-1 for now
...@@ -144,7 +145,7 @@ class Medusa(nn.Module): ...@@ -144,7 +145,7 @@ class Medusa(nn.Module):
token_prob_list.append(probs[:, seq_group.sample_indices]) token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices]) token_logprob_list.append(logprobs[:, seq_group.sample_indices])
outputs: List[Optional[SamplerOutput]] = [] outputs: list[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)): for idx in range(len(sampling_metadata.seq_groups)):
outputs.append( outputs.append(
SamplerOutput( SamplerOutput(
...@@ -160,7 +161,7 @@ class Medusa(nn.Module): ...@@ -160,7 +161,7 @@ class Medusa(nn.Module):
self, self,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]: ) -> list[SamplerOutput]:
return self.sample( return self.sample(
logits=self.compute_logits( logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states), hidden_states=self.forward(previous_hidden_states),
...@@ -169,10 +170,10 @@ class Medusa(nn.Module): ...@@ -169,10 +170,10 @@ class Medusa(nn.Module):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
weights_map = {} weights_map = {}
......
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only MiMo model compatible with HuggingFace weights.""" """Inference-only MiMo model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -87,8 +88,8 @@ class MiMoModel(Qwen2Model): ...@@ -87,8 +88,8 @@ class MiMoModel(Qwen2Model):
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
...@@ -97,7 +98,7 @@ class MiMoModel(Qwen2Model): ...@@ -97,7 +98,7 @@ class MiMoModel(Qwen2Model):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "mtp_layers" in name: if "mtp_layers" in name:
continue continue
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only MiMo-MTP model.""" """Inference-only MiMo-MTP model."""
from typing import Iterable, Optional, Set, Tuple from collections.abc import Iterable
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -193,8 +194,8 @@ class MiMoMTP(nn.Module): ...@@ -193,8 +194,8 @@ class MiMoMTP(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
...@@ -204,7 +205,7 @@ class MiMoMTP(nn.Module): ...@@ -204,7 +205,7 @@ class MiMoMTP(nn.Module):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
......
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -190,7 +191,7 @@ class MiniCPMAttention(nn.Module): ...@@ -190,7 +191,7 @@ class MiniCPMAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
...@@ -329,7 +330,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -329,7 +330,7 @@ class MiniCPMDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
...@@ -428,8 +429,8 @@ class MiniCPMModel(nn.Module): ...@@ -428,8 +429,8 @@ class MiniCPMModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -446,7 +447,7 @@ class MiniCPMModel(nn.Module): ...@@ -446,7 +447,7 @@ class MiniCPMModel(nn.Module):
for weight_name in ["w1", "w2", "w3"] for weight_name in ["w1", "w2", "w3"]
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -582,8 +583,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -582,8 +583,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader( loader = AutoWeightsLoader(
self, self,
skip_prefixes=(["lm_head."] skip_prefixes=(["lm_head."]
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" """Inference-only MiniCPM3 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional from typing import Any, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -58,7 +58,7 @@ class MiniCPM3Attention(nn.Module): ...@@ -58,7 +58,7 @@ class MiniCPM3Attention(nn.Module):
q_lora_rank: int, q_lora_rank: int,
kv_lora_rank: int, kv_lora_rank: int,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
......
...@@ -23,8 +23,7 @@ ...@@ -23,8 +23,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, from typing import Any, Callable, Literal, Optional, TypedDict, Union
Union)
import torch import torch
from torch import nn from torch import nn
...@@ -559,8 +558,8 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -559,8 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
self.audio_encoder_layer = -1 self.audio_encoder_layer = -1
return model return model
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["tts"]) loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -26,8 +26,7 @@ import math ...@@ -26,8 +26,7 @@ import math
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict, from typing import Any, Callable, Literal, Optional, TypedDict, Union
Union)
import numpy as np import numpy as np
import torch import torch
...@@ -118,7 +117,7 @@ class Resampler2_5(BaseResampler): ...@@ -118,7 +117,7 @@ class Resampler2_5(BaseResampler):
num_heads: int, num_heads: int,
kv_dim: Optional[int] = None, kv_dim: Optional[int] = None,
norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
max_size: Tuple[int, int] = (70, 70), max_size: tuple[int, int] = (70, 70),
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None: prefix: str = "") -> None:
super().__init__(num_queries, super().__init__(num_queries,
...@@ -133,7 +132,7 @@ class Resampler2_5(BaseResampler): ...@@ -133,7 +132,7 @@ class Resampler2_5(BaseResampler):
self._set_2d_pos_cache(self.max_size) self._set_2d_pos_cache(self.max_size)
def _set_2d_pos_cache(self, def _set_2d_pos_cache(self,
max_size: Tuple[int, int], max_size: tuple[int, int],
device: torch.types.Device = "cpu") -> None: device: torch.types.Device = "cpu") -> None:
pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim,
max_size, max_size,
...@@ -203,7 +202,7 @@ class Resampler2_5(BaseResampler): ...@@ -203,7 +202,7 @@ class Resampler2_5(BaseResampler):
return x return x
def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
version_float = getattr(config, "version", None) version_float = getattr(config, "version", None)
# The old configs do not include version number # The old configs do not include version number
...@@ -938,8 +937,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -938,8 +937,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
return self.llm.compute_logits(hidden_states, sampling_metadata) return self.llm.compute_logits(hidden_states, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import copy import copy
import math import math
import re import re
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
import torch.distributed import torch.distributed
...@@ -127,7 +128,7 @@ class MiniMaxText01RMSNormTP(CustomOp): ...@@ -127,7 +128,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection." assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x) return self._forward(x)
...@@ -178,7 +179,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp): ...@@ -178,7 +179,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device) self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
query_cast = query.to(self.cache_dtype) query_cast = query.to(self.cache_dtype)
...@@ -708,11 +709,11 @@ class MiniMaxText01DecoderLayer(nn.Module): ...@@ -708,11 +709,11 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self, def forward(self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: Union[List[Dict], Optional[torch.Tensor]], kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
is_warmup: bool = False, is_warmup: bool = False,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]: **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
forward_context = get_forward_context() forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
...@@ -1072,10 +1073,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, ...@@ -1072,10 +1073,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
device=device), device=device),
}) })
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
def which_layer(name: str) -> int: def which_layer(name: str) -> int:
if "layers" in name: if "layers" in name:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Literal, Optional, Set, Tuple, TypedDict, Union, cast from typing import Literal, Optional, TypedDict, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -357,7 +357,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -357,7 +357,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict, from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
TypeVar, Union) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -589,8 +589,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, ...@@ -589,8 +589,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
......
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
...@@ -314,8 +315,8 @@ class MixtralModel(nn.Module): ...@@ -314,8 +315,8 @@ class MixtralModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -332,7 +333,7 @@ class MixtralModel(nn.Module): ...@@ -332,7 +333,7 @@ class MixtralModel(nn.Module):
num_experts=self.config.num_local_experts) num_experts=self.config.num_local_experts)
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if (self.quant_config is not None and if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))): (scale_name := self.quant_config.get_cache_scale(name))):
...@@ -479,7 +480,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -479,7 +480,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"]) loader = AutoWeightsLoader(self, skip_prefixes=["rotary_emb.inv_freq"])
return loader.load_weights(weights) return loader.load_weights(weights)
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, Optional, Set, Tuple, Union from collections.abc import Iterable
from typing import Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -397,8 +398,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -397,8 +398,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
sampling_metadata) sampling_metadata)
return logits return logits
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
...@@ -407,7 +408,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -407,7 +408,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
"""PyTorch Mllama model.""" """PyTorch Mllama model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
...@@ -224,7 +224,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -224,7 +224,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
return mm_inputs return mm_inputs
def _get_num_image_in_last_group(self, prompt_token_ids: List[int]) -> int: def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int:
num_images = 0 num_images = 0
for token_id in prompt_token_ids[::-1]: for token_id in prompt_token_ids[::-1]:
if token_id == self.info.get_hf_config().image_token_index: if token_id == self.info.get_hf_config().image_token_index:
...@@ -370,8 +370,8 @@ class ColumnParallelConv2dPatch(torch.nn.Module): ...@@ -370,8 +370,8 @@ class ColumnParallelConv2dPatch(torch.nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
kernel_size: Union[int, Tuple[int, int]], kernel_size: Union[int, tuple[int, int]],
stride: Union[int, Tuple[int, int]], stride: Union[int, tuple[int, int]],
bias: bool = False, bias: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -603,7 +603,7 @@ class MllamaVisionEncoder(nn.Module): ...@@ -603,7 +603,7 @@ class MllamaVisionEncoder(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]: ) -> Union[BaseModelOutput]:
encoder_states = () encoder_states = ()
for i, encoder_layer in enumerate(self.layers): for i, encoder_layer in enumerate(self.layers):
...@@ -878,7 +878,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -878,7 +878,7 @@ class MllamaTextCrossAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[list[tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor], cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
q, k, v = self.qkv_proj(hidden_states, cross_attention_states) q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
...@@ -905,7 +905,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -905,7 +905,7 @@ class MllamaTextCrossAttention(nn.Module):
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
kv_range_for_decode: List[Tuple[int, int]], kv_range_for_decode: list[tuple[int, int]],
) -> torch.Tensor: ) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank] kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
...@@ -1019,7 +1019,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -1019,7 +1019,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor, cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor, cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[list[tuple[int, int]]],
full_text_row_masked_out_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
...@@ -1089,8 +1089,8 @@ class MllamaTextModel(nn.Module): ...@@ -1089,8 +1089,8 @@ class MllamaTextModel(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[list[tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
skip_cross_attention: bool, skip_cross_attention: bool,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1150,8 +1150,8 @@ class MllamaForCausalLM(nn.Module): ...@@ -1150,8 +1150,8 @@ class MllamaForCausalLM(nn.Module):
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[Tuple[int, int]]], kv_range_for_decode: Optional[list[tuple[int, int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
torch.Tensor]], torch.Tensor]],
skip_cross_attention: bool, skip_cross_attention: bool,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1221,7 +1221,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1221,7 +1221,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
return logits return logits
def unpack_data(self, def unpack_data(self,
image_data: Union[List[torch.Tensor], torch.Tensor], image_data: Union[list[torch.Tensor], torch.Tensor],
padding_value=0) -> torch.Tensor: padding_value=0) -> torch.Tensor:
if isinstance(image_data, torch.Tensor): if isinstance(image_data, torch.Tensor):
# torch.Tensor # torch.Tensor
...@@ -1230,7 +1230,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1230,7 +1230,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
assert isinstance( assert isinstance(
image_data[0], image_data[0],
torch.Tensor), "Image data is not properly batched." torch.Tensor), "Image data is not properly batched."
# List[torch.Tensor] # list[torch.Tensor]
bsz = len(image_data) bsz = len(image_data)
max_length = max(t.size(0) for t in image_data) max_length = max(t.size(0) for t in image_data)
trailing_dims = image_data[0].shape[1:] trailing_dims = image_data[0].shape[1:]
...@@ -1248,24 +1248,24 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1248,24 +1248,24 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def _parse_and_validate_image_input(self, **kwargs: object): def _parse_and_validate_image_input(self, **kwargs: object):
# tensor with the same shape will be batched together by # tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be: # MultiModalKwargs.batch, so pixel_values here can be:
# - List[torch.Tensor]: # - list[torch.Tensor]:
# with shape (num_image, num_tiles, 3, image_res, image_res) # with shape (num_image, num_tiles, 3, image_res, image_res)
# - torch.Tensor: # - torch.Tensor:
# with shape (bs, num_image, num_tiles, 3, image_res, image_res) # with shape (bs, num_image, num_tiles, 3, image_res, image_res)
pixel_values: Optional[Union[List[List[torch.Tensor]], pixel_values: Optional[Union[list[list[torch.Tensor]],
List[torch.Tensor], list[torch.Tensor],
torch.Tensor]] = kwargs.pop( torch.Tensor]] = kwargs.pop(
"pixel_values", None) "pixel_values", None)
image_embeds: Optional[Union[List[List[torch.Tensor]], image_embeds: Optional[Union[list[list[torch.Tensor]],
List[torch.Tensor], list[torch.Tensor],
torch.Tensor]] = kwargs.pop( torch.Tensor]] = kwargs.pop(
"image_embeds", None) "image_embeds", None)
aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]],
List[torch.Tensor], list[torch.Tensor],
torch.Tensor]] = kwargs.pop( torch.Tensor]] = kwargs.pop(
"aspect_ratio_ids", None) "aspect_ratio_ids", None)
aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]],
List[torch.Tensor], list[torch.Tensor],
torch.Tensor]] = kwargs.pop( torch.Tensor]] = kwargs.pop(
"aspect_ratio_mask", None) "aspect_ratio_mask", None)
...@@ -1293,10 +1293,10 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1293,10 +1293,10 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def _get_and_validate_encoder_lens( def _get_and_validate_encoder_lens(
self, self,
encoder_seq_lens: List[int], encoder_seq_lens: list[int],
num_tiles: List[List[int]], num_tiles: list[list[int]],
num_tokens_per_tile: int, num_tokens_per_tile: int,
) -> List[int]: ) -> list[int]:
# Get the actual number of encoder tokens for each sample. # Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last # Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the # group of images for each sample, which is used to cheat the
...@@ -1318,7 +1318,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1318,7 +1318,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
def flat_encoder_result(self, cross_attention_states: torch.Tensor, def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int]): actual_encoder_seq_lens: list[int]):
cross_attention_states_flat = torch.zeros( cross_attention_states_flat = torch.zeros(
sum(actual_encoder_seq_lens), sum(actual_encoder_seq_lens),
...@@ -1342,8 +1342,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1342,8 +1342,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
image_inputs: MllamaImagePixelInputs, image_inputs: MllamaImagePixelInputs,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
actual_encoder_seq_lens: List[int], actual_encoder_seq_lens: list[int],
) -> Tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
# NOTE: llama's reference implementation runs vision model on CPU # NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data'] pixel_values = image_inputs['data']
aspect_ratio_ids = image_inputs['aspect_ratio_ids'] aspect_ratio_ids = image_inputs['aspect_ratio_ids']
...@@ -1367,10 +1367,10 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1367,10 +1367,10 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
num_tiles: List[List[int]], num_tiles: list[list[int]],
num_tokens_per_tile: int, num_tokens_per_tile: int,
dtype: torch.dtype, dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
token_ids = input_ids.tolist() token_ids = input_ids.tolist()
start = 0 start = 0
batch_token_ids = [] batch_token_ids = []
...@@ -1422,7 +1422,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1422,7 +1422,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
**kwargs: object, **kwargs: object,
) -> Union[Tuple, CausalLMOutputWithPast]: ) -> Union[CausalLMOutputWithPast]:
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
if attn_metadata.num_prefill_tokens > 0 and \ if attn_metadata.num_prefill_tokens > 0 and \
attn_metadata.num_decode_tokens > 0: attn_metadata.num_decode_tokens > 0:
...@@ -1476,8 +1476,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1476,8 +1476,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
return outputs return outputs
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
...@@ -1487,7 +1487,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1487,7 +1487,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1), (".gate_up_proj", ".up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
updated_params: Set[str] = set() updated_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
if 'patch_embedding.weight' in name: if 'patch_embedding.weight' in name:
name = name.replace('patch_embedding.weight', name = name.replace('patch_embedding.weight',
...@@ -1538,7 +1538,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1538,7 +1538,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
tower_model="vision_model") tower_model="vision_model")
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: def skip_attention_mask(sparse_mask: list[list[int]]) -> bool:
for mask in sparse_mask: for mask in sparse_mask:
# Skip text-only samples. # Skip text-only samples.
if len(mask) == 0: if len(mask) == 0:
...@@ -1556,10 +1556,10 @@ def skip_attention_mask(sparse_mask: List[List[int]]) -> bool: ...@@ -1556,10 +1556,10 @@ def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
def convert_sparse_cross_attention_mask_to_dense( def convert_sparse_cross_attention_mask_to_dense(
sparse_mask: List[List[List[int]]], sparse_mask: list[list[list[int]]],
num_tiles: List[List[int]], num_tiles: list[list[int]],
lengths: List[int], lengths: list[int],
) -> Tuple[np.ndarray, List[Tuple[int, int]]]: ) -> tuple[np.ndarray, list[tuple[int, int]]]:
total_length = sum(lengths) total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles]) total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64) dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import math import math
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from itertools import tee from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -582,7 +582,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -582,7 +582,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> List[PromptUpdate]: ) -> list[PromptUpdate]:
assert ( assert (
mm_items.get_count("image", strict=False) == 0 mm_items.get_count("image", strict=False) == 0
or "aspect_ratios" in out_mm_kwargs or "aspect_ratios" in out_mm_kwargs
...@@ -778,26 +778,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -778,26 +778,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
def separate_weights( def separate_weights(
self, self,
weights: Iterable[Tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],
prefix: str, prefix: str,
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[ ) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[
str, torch.Tensor]]]: str, torch.Tensor]]]:
weights1, weights2 = tee(weights, 2) weights1, weights2 = tee(weights, 2)
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]: def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
for name, data in weights1: for name, data in weights1:
if name.startswith(prefix): if name.startswith(prefix):
yield (name, data) yield (name, data)
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]: def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]:
for name, data in weights2: for name, data in weights2:
if not name.startswith(prefix): if not name.startswith(prefix):
yield (name, data) yield (name, data)
return get_prefix_weights(), get_other_weights() return get_prefix_weights(), get_other_weights()
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
...@@ -806,7 +806,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -806,7 +806,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"), (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
updated_params: Set[str] = set() updated_params: set[str] = set()
# language_model is an Llama4ForCausalLM instance. We load it's # language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine. # using llama4's load_weights routine.
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from typing import Iterable, List, Set, Tuple from collections.abc import Iterable
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -148,7 +148,7 @@ class MLPSpeculator(nn.Module): ...@@ -148,7 +148,7 @@ class MLPSpeculator(nn.Module):
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
num_predict_tokens: int, num_predict_tokens: int,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]: ) -> list[SamplerOutput]:
if num_predict_tokens > self.max_speculative_tokens: if num_predict_tokens > self.max_speculative_tokens:
raise ValueError(f"Max speculative tokens for model is " raise ValueError(f"Max speculative tokens for model is "
f"{self.max_speculative_tokens}, but " f"{self.max_speculative_tokens}, but "
...@@ -190,10 +190,10 @@ class MLPSpeculator(nn.Module): ...@@ -190,10 +190,10 @@ class MLPSpeculator(nn.Module):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: set[str] = set()
for name, loaded_weight in weights: for name, loaded_weight in weights:
name = name.replace("speculator.", "") name = name.replace("speculator.", "")
param = params_dict.get(name) param = params_dict.get(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment