Unverified Commit 6223dd81 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update deprecated type hinting in `model_executor/layers` (#18056)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 906f0598
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from functools import cached_property from functools import cached_property
from importlib.util import find_spec from importlib.util import find_spec
from typing import Dict, Optional, Tuple from typing import Optional
import torch import torch
import torch.jit import torch.jit
...@@ -65,7 +65,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -65,7 +65,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None, seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects """Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token tokens proposed by the draft model using the probability of each token
...@@ -161,8 +161,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -161,8 +161,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size] target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k] draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[Dict[int, torch.Generator]], seeded_seqs: Optional[dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence. """Perform modified rejection sampling on each sequence.
Returns: Returns:
...@@ -194,7 +194,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -194,7 +194,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
return accepted, recovered_token_ids return accepted, recovered_token_ids
def _create_uniform_samples(self, def _create_uniform_samples(self,
seeded_seqs: Optional[Dict[int, seeded_seqs: Optional[dict[int,
torch.Generator]], torch.Generator]],
batch_size: int, k: int, batch_size: int, k: int,
device: torch.device) -> torch.Tensor: device: torch.device) -> torch.Tensor:
...@@ -210,7 +210,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -210,7 +210,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
a seed. a seed.
Args: Args:
seeded_seqs : Optional[Dict[int, torch.Generator]] seeded_seqs : Optional[dict[int, torch.Generator]]
A dictionary mapping indices in the batch to A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are `torch.Generator` objects. If `None`, all samples are
generated without a seed. generated without a seed.
...@@ -255,7 +255,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -255,7 +255,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size] target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size] draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k] draft_token_ids: torch.Tensor, # [batch_size, k]
seeded_seqs: Optional[Dict[int, torch.Generator]], seeded_seqs: Optional[dict[int, torch.Generator]],
) -> torch.Tensor: ) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be True, then a token can be accepted, else it should be
...@@ -379,7 +379,7 @@ def _multinomial( ...@@ -379,7 +379,7 @@ def _multinomial(
probs: torch.Tensor, probs: torch.Tensor,
num_samples: int, num_samples: int,
k: int, k: int,
seeded_seqs: Dict[int, torch.Generator], seeded_seqs: dict[int, torch.Generator],
) -> torch.Tensor: ) -> torch.Tensor:
if num_samples > 1: if num_samples > 1:
......
...@@ -33,7 +33,7 @@ Example models: Qwen (Qwen-VL), MiniCPM-V 2.0 ...@@ -33,7 +33,7 @@ Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
""" """
import math import math
from functools import partial from functools import partial
from typing import Callable, Optional, Tuple, Union from typing import Callable, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -69,7 +69,7 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, ...@@ -69,7 +69,7 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor,
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid( def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: np.ndarray, embed_dim: int, pos: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor: version: tuple[int, int] = (2, 0)) -> torch.Tensor:
""" """
embed_dim: output dimension for each position embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W) pos: a list of positions to be encoded: size (M,) / (H, W)
...@@ -96,7 +96,7 @@ def get_1d_sincos_pos_embed_from_grid( ...@@ -96,7 +96,7 @@ def get_1d_sincos_pos_embed_from_grid(
def get_2d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: np.ndarray, embed_dim: int, grid: np.ndarray,
version: Tuple[int, int] = (2, 0)) -> torch.Tensor: version: tuple[int, int] = (2, 0)) -> torch.Tensor:
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h # use half of dimensions to encode grid_h
...@@ -114,9 +114,9 @@ def get_2d_sincos_pos_embed_from_grid( ...@@ -114,9 +114,9 @@ def get_2d_sincos_pos_embed_from_grid(
def get_2d_sincos_pos_embed( def get_2d_sincos_pos_embed(
embed_dim: int, embed_dim: int,
grid_size: Union[int, Tuple[int, int]], grid_size: Union[int, tuple[int, int]],
cls_token: bool = False, cls_token: bool = False,
version: Tuple[int, int] = (2, 0), version: tuple[int, int] = (2, 0),
) -> torch.Tensor: ) -> torch.Tensor:
""" """
grid_size: int of the grid height and width grid_size: int of the grid height and width
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
# limitations under the License. # limitations under the License.
"""Rotary Positional Embeddings.""" """Rotary Positional Embeddings."""
import math import math
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -140,7 +140,7 @@ class RotaryEmbedding(CustomOp): ...@@ -140,7 +140,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""A PyTorch-native implementation of forward().""" """A PyTorch-native implementation of forward()."""
if offsets is not None: if offsets is not None:
positions = positions + offsets positions = positions + offsets
...@@ -174,7 +174,7 @@ class RotaryEmbedding(CustomOp): ...@@ -174,7 +174,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
...@@ -202,7 +202,7 @@ class RotaryEmbedding(CustomOp): ...@@ -202,7 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
...@@ -232,7 +232,7 @@ class RotaryEmbedding(CustomOp): ...@@ -232,7 +232,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from habana_frameworks.torch.hpex.kernels import ( from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode, apply_rotary_pos_emb) RotaryPosEmbeddingMode, apply_rotary_pos_emb)
if offsets is not None: if offsets is not None:
...@@ -290,7 +290,7 @@ class RotaryEmbedding(CustomOp): ...@@ -290,7 +290,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
def _apply_rotary_emb_neuron( def _apply_rotary_emb_neuron(
x: torch.Tensor, x: torch.Tensor,
...@@ -406,23 +406,23 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -406,23 +406,23 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
scaling_factors: Union[List[float], float], scaling_factors: Union[list[float], float],
dtype: torch.dtype, dtype: torch.dtype,
) -> None: ) -> None:
if isinstance(scaling_factors, float): if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors] scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa self.scaling_factors: list[float] = scaling_factors # noqa
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype) is_neox_style, dtype)
# Lazy initialized. # Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int] self._scaling_factor_to_offset: dict[float, int]
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base) inv_freq = self._compute_inv_freq(self.base)
cache_list: List[torch.Tensor] = [] cache_list: list[torch.Tensor] = []
# offsets to the next cache in a tensor. # offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors. # Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = [] offsets: list[int] = []
for scaling_factor in self.scaling_factors: for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original # NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling. # maximum length before applying the rope scaling.
...@@ -452,7 +452,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -452,7 +452,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
return torch.cat(cache_list, dim=0) return torch.cat(cache_list, dim=0)
@property @property
def scaling_factor_to_offset(self) -> Dict[float, int]: def scaling_factor_to_offset(self) -> dict[float, int]:
return self._scaling_factor_to_offset return self._scaling_factor_to_offset
...@@ -512,7 +512,7 @@ def _yarn_find_correction_range( ...@@ -512,7 +512,7 @@ def _yarn_find_correction_range(
high_rot: int, high_rot: int,
dim: int, dim: int,
base: float = 10000, base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]: max_position_embeddings: int = 2048) -> tuple[int, int]:
low = math.floor( low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil( high = math.ceil(
...@@ -613,8 +613,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -613,8 +613,8 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
short_factor: List[float], short_factor: list[float],
long_factor: List[float], long_factor: list[float],
short_mscale: Optional[float] = None, short_mscale: Optional[float] = None,
long_mscale: Optional[float] = None, long_mscale: Optional[float] = None,
): ):
...@@ -662,7 +662,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -662,7 +662,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
long_short_cache, long_short_cache,
persistent=False) persistent=False)
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))) 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
...@@ -671,7 +671,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -671,7 +671,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
def _compute_cos_sin_cache( def _compute_cos_sin_cache(
self, self,
max_position_embeddings: int, max_position_embeddings: int,
rescale_factors: List[float], rescale_factors: list[float],
mscale: float, mscale: float,
) -> torch.Tensor: ) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors) inv_freq = self._compute_inv_freq(rescale_factors)
...@@ -688,7 +688,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -688,7 +688,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None assert key is not None
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size)
...@@ -799,7 +799,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -799,7 +799,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
assert key is not None assert key is not None
query_rot = query[..., :self.rotary_dim] query_rot = query[..., :self.rotary_dim]
...@@ -930,7 +930,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): ...@@ -930,7 +930,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
self, self,
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None assert key is not None
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device) self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape( query_ = torch.view_as_complex(query.float().reshape(
...@@ -958,7 +958,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -958,7 +958,7 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
mrope_section: Optional[List[int]] = None, mrope_section: Optional[list[int]] = None,
) -> None: ) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of # In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get # the input video. We enlarge max_position_embeddings to 4 times to get
...@@ -976,7 +976,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -976,7 +976,7 @@ class MRotaryEmbedding(RotaryEmbedding):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward(). """PyTorch-native implementation equivalent to forward().
Args: Args:
...@@ -1024,16 +1024,16 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1024,16 +1024,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod @classmethod
def get_input_positions( def get_input_positions(
cls, cls,
input_tokens: List[int], input_tokens: list[int],
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
video_grid_thw: Optional[Union[List[List[int]], torch.Tensor]], video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
second_per_grid_ts: Optional[List[float]], second_per_grid_ts: Optional[list[float]],
context_len: int = 0, context_len: int = 0,
seq_len: Optional[int] = None, seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None, audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False, use_audio_in_video: bool = False,
) -> Tuple[List[List[int]], int]: ) -> tuple[list[list[int]], int]:
"""Get mrope input positions and delta value.""" """Get mrope input positions and delta value."""
image_grid_thw = [] if image_grid_thw is None else image_grid_thw image_grid_thw = [] if image_grid_thw is None else image_grid_thw
...@@ -1059,16 +1059,16 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1059,16 +1059,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod @classmethod
def get_input_positions_tensor( def get_input_positions_tensor(
cls, cls,
input_tokens: List[int], input_tokens: list[int],
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor], image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor], video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: List[float], second_per_grid_ts: list[float],
context_len: int = 0, context_len: int = 0,
seq_len: Optional[int] = None, seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None, audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False, use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
from vllm.transformers_utils.config import thinker_uses_mrope from vllm.transformers_utils.config import thinker_uses_mrope
if thinker_uses_mrope(hf_config): if thinker_uses_mrope(hf_config):
return cls._omni_get_input_positions_tensor( return cls._omni_get_input_positions_tensor(
...@@ -1096,14 +1096,14 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1096,14 +1096,14 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod @classmethod
def _vl_get_input_positions_tensor( def _vl_get_input_positions_tensor(
cls, cls,
input_tokens: List[int], input_tokens: list[int],
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor], image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor], video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: List[float], second_per_grid_ts: list[float],
context_len: int = 0, context_len: int = 0,
seq_len: Optional[int] = None, seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value.""" """Get mrope input positions and delta value."""
image_token_id = hf_config.image_token_id image_token_id = hf_config.image_token_id
...@@ -1195,16 +1195,16 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1195,16 +1195,16 @@ class MRotaryEmbedding(RotaryEmbedding):
@classmethod @classmethod
def _omni_get_input_positions_tensor( def _omni_get_input_positions_tensor(
cls, cls,
input_tokens: List[int], input_tokens: list[int],
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor], image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor], video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[List[float]] = None, second_per_grid_ts: Optional[list[float]] = None,
context_len: int = 0, context_len: int = 0,
seq_len: Optional[int] = None, seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None, audio_feature_lengths: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False, use_audio_in_video: bool = False,
) -> Tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version). """Get mrope input positions and delta value (Qwen2.5-Omni version).
Differences from MRotaryEmbedding: Differences from MRotaryEmbedding:
...@@ -1329,7 +1329,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1329,7 +1329,7 @@ class MRotaryEmbedding(RotaryEmbedding):
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2 place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2 pure_audio_len = place_num - 2
added_audio_len = 0 added_audio_len = 0
audio_llm_pos_ids_list: List[torch.Tensor] = [] audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk: for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = len( vision_ntoken_per_chunk = len(
t_chunk) * grid_h * grid_w // (spatial_merge_size**2) t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
...@@ -1382,7 +1382,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1382,7 +1382,7 @@ class MRotaryEmbedding(RotaryEmbedding):
start_idx: int, start_idx: int,
vision_idx: int, vision_idx: int,
spatial_merge_size: int, spatial_merge_size: int,
t_index: List[int], t_index: list[int],
grid_hs: torch.Tensor, grid_hs: torch.Tensor,
grid_ws: torch.Tensor, grid_ws: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -1402,8 +1402,8 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1402,8 +1402,8 @@ class MRotaryEmbedding(RotaryEmbedding):
@staticmethod @staticmethod
def _split_list_into_ranges(lst: torch.Tensor, def _split_list_into_ranges(lst: torch.Tensor,
interval: int) -> List[List[int]]: interval: int) -> list[list[int]]:
ranges: List[List[int]] = [[] ranges: list[list[int]] = [[]
for _ in range((max(lst) // interval) + 1)] for _ in range((max(lst) // interval) + 1)]
for num in lst: for num in lst:
index = num // interval index = num // interval
...@@ -1415,7 +1415,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1415,7 +1415,7 @@ class MRotaryEmbedding(RotaryEmbedding):
mrope_position_delta: int, mrope_position_delta: int,
context_len: int, context_len: int,
seq_len: int, seq_len: int,
) -> List[List[int]]: ) -> list[list[int]]:
return [ return [
list( list(
range(context_len + mrope_position_delta, range(context_len + mrope_position_delta,
...@@ -1438,9 +1438,9 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1438,9 +1438,9 @@ class MRotaryEmbedding(RotaryEmbedding):
cls, cls,
thinker_config: PretrainedConfig, thinker_config: PretrainedConfig,
audio_len: int, audio_len: int,
video_grid_thw: Union[List[int], torch.Tensor], video_grid_thw: Union[list[int], torch.Tensor],
video_second_per_grid_t: float, video_second_per_grid_t: float,
) -> List[int]: ) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True. """Get video prompt updates when `use_audio_in_video` is True.
In this case, audio and vision update ids will be split into In this case, audio and vision update ids will be split into
...@@ -1593,7 +1593,7 @@ class DualChunkRotaryEmbedding(CustomOp): ...@@ -1593,7 +1593,7 @@ class DualChunkRotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim] query_rot = query[..., :self.rotary_dim]
...@@ -1664,7 +1664,7 @@ class DualChunkRotaryEmbedding(CustomOp): ...@@ -1664,7 +1664,7 @@ class DualChunkRotaryEmbedding(CustomOp):
return s return s
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} _ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
def get_rope( def get_rope(
...@@ -1673,10 +1673,10 @@ def get_rope( ...@@ -1673,10 +1673,10 @@ def get_rope(
max_position: int, max_position: int,
base: int, base: int,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0, partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None, dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if dtype is None: if dtype is None:
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
"""A layer that samples the next tokens from the model's outputs.""" """A layer that samples the next tokens from the model's outputs."""
import itertools import itertools
import warnings import warnings
from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from importlib.util import find_spec from importlib.util import find_spec
from math import inf from math import inf
from typing import Dict, Iterator, List, Optional, Tuple, Union from typing import Optional, Union
import msgspec import msgspec
import torch import torch
...@@ -42,14 +43,14 @@ def get_sampler() -> torch.nn.Module: ...@@ -42,14 +43,14 @@ def get_sampler() -> torch.nn.Module:
# (num_token_ids, num_parent_ids) per sequence group. # (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]] SampleResultType = list[tuple[list[int], list[int]]]
# Types of temporary data structures used for # Types of temporary data structures used for
# computing sample_result # computing sample_result
SampleMetadataType = Dict[SamplingType, Tuple[List[int], SampleMetadataType = dict[SamplingType, tuple[list[int],
List[SequenceGroupToSample]]] list[SequenceGroupToSample]]]
MultinomialSamplesType = Dict[SamplingType, torch.Tensor] MultinomialSamplesType = dict[SamplingType, torch.Tensor]
SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] SampleResultsDictType = dict[int, tuple[list[int], list[int]]]
# Encapsulates temporary data structures for computing # Encapsulates temporary data structures for computing
...@@ -76,7 +77,7 @@ class SampleResultArgsType: ...@@ -76,7 +77,7 @@ class SampleResultArgsType:
MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType]
# Abbreviation of the _sample() return type # Abbreviation of the _sample() return type
SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]]
class SamplerOutput( class SamplerOutput(
...@@ -90,7 +91,7 @@ class SamplerOutput( ...@@ -90,7 +91,7 @@ class SamplerOutput(
also has optional fields for device tensors. also has optional fields for device tensors.
""" """
outputs: List[CompletionSequenceGroupOutput] outputs: list[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token. # On-device tensor containing probabilities of each token.
sampled_token_probs: Optional[torch.Tensor] = None sampled_token_probs: Optional[torch.Tensor] = None
...@@ -350,7 +351,7 @@ def _apply_min_tokens_penalty( ...@@ -350,7 +351,7 @@ def _apply_min_tokens_penalty(
have not been generated yet have not been generated yet
""" """
# list of indices in logits that will be set to -inf # list of indices in logits that will be set to -inf
logits_to_penalize: List[Tuple[int, int]] = [] logits_to_penalize: list[tuple[int, int]] = []
logits_applied = 0 logits_applied = 0
for seq_group in sampling_metadata.seq_groups: for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
...@@ -366,7 +367,7 @@ def _apply_min_tokens_penalty( ...@@ -366,7 +367,7 @@ def _apply_min_tokens_penalty(
min_tokens = sampling_params.min_tokens min_tokens = sampling_params.min_tokens
token_ids_to_penalize = sampling_params.all_stop_token_ids token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize: if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize: List[int] = [] seqs_to_penalize: list[int] = []
for j, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids_array) < min_tokens: if len(seq_data.output_token_ids_array) < min_tokens:
...@@ -436,7 +437,7 @@ def _apply_min_p( ...@@ -436,7 +437,7 @@ def _apply_min_p(
def _greedy_sample( def _greedy_sample(
selected_seq_groups: List[SequenceGroupToSample], selected_seq_groups: list[SequenceGroupToSample],
samples: torch.Tensor, samples: torch.Tensor,
) -> SampleResultType: ) -> SampleResultType:
"""Run greedy sampling on a given samples. """Run greedy sampling on a given samples.
...@@ -471,7 +472,7 @@ def _greedy_sample( ...@@ -471,7 +472,7 @@ def _greedy_sample(
def _random_sample( def _random_sample(
selected_seq_groups: List[SequenceGroupToSample], selected_seq_groups: list[SequenceGroupToSample],
random_samples: torch.Tensor, random_samples: torch.Tensor,
) -> SampleResultType: ) -> SampleResultType:
"""Run random sampling on a given samples. """Run random sampling on a given samples.
...@@ -522,7 +523,7 @@ def _random_sample( ...@@ -522,7 +523,7 @@ def _random_sample(
def _multinomial( def _multinomial(
probs: torch.Tensor, probs: torch.Tensor,
num_samples: int, num_samples: int,
seq_groups: Optional[List[SequenceGroupToSample]] = None, seq_groups: Optional[list[SequenceGroupToSample]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if num_samples > 1: if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0) probs = probs.repeat_interleave(num_samples, dim=0)
...@@ -543,7 +544,7 @@ def _multinomial( ...@@ -543,7 +544,7 @@ def _multinomial(
def _top_k_top_p_multinomial_with_flashinfer( def _top_k_top_p_multinomial_with_flashinfer(
probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]): num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]):
max_top_k_round = 32 max_top_k_round = 32
if num_samples > 1: if num_samples > 1:
probs = probs.repeat_interleave(num_samples, dim=0) probs = probs.repeat_interleave(num_samples, dim=0)
...@@ -648,7 +649,7 @@ def _sample_with_torch( ...@@ -648,7 +649,7 @@ def _sample_with_torch(
tensors required for Pythonization tensors required for Pythonization
''' '''
categorized_seq_group_ids: Dict[SamplingType, List[int]] = { categorized_seq_group_ids: dict[SamplingType, list[int]] = {
t: [] t: []
for t in SamplingType for t in SamplingType
} }
...@@ -812,7 +813,7 @@ def get_logprobs( ...@@ -812,7 +813,7 @@ def get_logprobs(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: SampleResultType, sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: ) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]:
"""Return sample logprobs and prompt logprobs. """Return sample logprobs and prompt logprobs.
The logic consists of 3 parts. The logic consists of 3 parts.
...@@ -841,9 +842,9 @@ def get_logprobs( ...@@ -841,9 +842,9 @@ def get_logprobs(
""" """
# The index of query token to calculate logprobs. It includes both # The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices. # prompt and sample logprob indices.
query_indices: List[int] = [] query_indices: list[int] = []
# The next token ids to get the logprob value from. # The next token ids to get the logprob value from.
next_token_ids: List[int] = [] next_token_ids: list[int] = []
# The largest requested number of logprobs. We find logprobs as many as the # The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API. If every logprobs is None, it will be # largest num logprobs in this API. If every logprobs is None, it will be
# set to -1. # set to -1.
...@@ -925,8 +926,8 @@ def get_logprobs( ...@@ -925,8 +926,8 @@ def get_logprobs(
ranks = ranks.to('cpu') ranks = ranks.to('cpu')
# Find prompt/sample logprobs. # Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: List[SampleLogprobs] = [] sample_logprobs_per_seq_group: list[SampleLogprobs] = []
top_logprob_idx = 0 top_logprob_idx = 0
selected_logprobs_idx = 0 selected_logprobs_idx = 0
...@@ -977,7 +978,7 @@ def _get_prompt_logprob_if_needed( ...@@ -977,7 +978,7 @@ def _get_prompt_logprob_if_needed(
for idx, token_id in enumerate(next_prompt_tokens): for idx, token_id in enumerate(next_prompt_tokens):
# Calculate the prompt logprob of the real prompt tokens. # Calculate the prompt logprob of the real prompt tokens.
# {token_id: (logprob, rank_from_vocab)} # {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { prompt_logprobs_dict: dict[int, tuple[float, int]] = {
token_id: (selected_logprob_items[idx], rank_items[idx]) token_id: (selected_logprob_items[idx], rank_items[idx])
} }
...@@ -1009,7 +1010,7 @@ def _get_prompt_logprob_if_needed( ...@@ -1009,7 +1010,7 @@ def _get_prompt_logprob_if_needed(
def _get_sampled_logprob_if_needed( def _get_sampled_logprob_if_needed(
seq_group: SequenceGroupToSample, seq_group: SequenceGroupToSample,
sample_result: Tuple[List[int], List[int]], sample_result: tuple[list[int], list[int]],
selected_logprobs: torch.Tensor, selected_logprobs: torch.Tensor,
ranks: torch.Tensor, ranks: torch.Tensor,
top_token_ids: torch.Tensor, top_token_ids: torch.Tensor,
...@@ -1130,9 +1131,9 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, ...@@ -1130,9 +1131,9 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output( def _build_sampler_output(
maybe_deferred_sample_results: MaybeDeferredSampleResultType, maybe_deferred_sample_results: MaybeDeferredSampleResultType,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], prompt_logprobs: Optional[list[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]], sample_logprobs: Optional[list[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor,
torch.Tensor]], torch.Tensor]],
skip_sampler_cpu_output: bool = False, skip_sampler_cpu_output: bool = False,
) -> SamplerOutput: ) -> SamplerOutput:
...@@ -1144,7 +1145,7 @@ def _build_sampler_output( ...@@ -1144,7 +1145,7 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling. speculative decoding rejection sampling.
""" """
sampler_output: List[CompletionSequenceGroupOutput] = [] sampler_output: list[CompletionSequenceGroupOutput] = []
if skip_sampler_cpu_output: if skip_sampler_cpu_output:
assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) assert isinstance(maybe_deferred_sample_results, SampleResultArgsType)
...@@ -1166,7 +1167,7 @@ def _build_sampler_output( ...@@ -1166,7 +1167,7 @@ def _build_sampler_output(
prompt_logprobs, sample_logprobs): prompt_logprobs, sample_logprobs):
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = [] seq_outputs: list[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip( for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs): parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append( seq_outputs.append(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, Optional, Union from typing import Optional, Union
import torch import torch
import torch.jit import torch.jit
...@@ -253,6 +253,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): ...@@ -253,6 +253,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor, draft_token_ids: torch.Tensor,
seeded_seqs: Optional[Dict[int, torch.Generator]] = None, seeded_seqs: Optional[dict[int, torch.Generator]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Utility methods for model layers.""" """Utility methods for model layers."""
from typing import Callable, Optional, Tuple from typing import Callable, Optional
import torch import torch
...@@ -13,7 +13,7 @@ def get_token_bin_counts_and_mask( ...@@ -13,7 +13,7 @@ def get_token_bin_counts_and_mask(
tokens: torch.Tensor, tokens: torch.Tensor,
vocab_size: int, vocab_size: int,
num_seqs: int, num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Compute the bin counts for the tokens. # Compute the bin counts for the tokens.
# vocab_size + 1 for padding. # vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1), bin_counts = torch.zeros((num_seqs, vocab_size + 1),
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -25,7 +26,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
"""Create weights for embedding layer.""" """Create weights for embedding layer."""
...@@ -141,7 +142,7 @@ def get_masked_input_and_mask( ...@@ -141,7 +142,7 @@ def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int, input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int, org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int, added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: added_vocab_end_index: int) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below # torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast # into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & ( org_vocab_mask = (input_ >= org_vocab_start_index) & (
...@@ -298,7 +299,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -298,7 +299,7 @@ class VocabParallelEmbedding(torch.nn.Module):
org_vocab_start_index, org_vocab_end_index, org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index) added_vocab_start_index, added_vocab_end_index)
def get_sharded_to_full_mapping(self) -> Optional[List[int]]: def get_sharded_to_full_mapping(self) -> Optional[list[int]]:
"""Get a mapping that can be used to reindex the gathered """Get a mapping that can be used to reindex the gathered
logits for sampling. logits for sampling.
...@@ -312,9 +313,9 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -312,9 +313,9 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size < 2: if self.tp_size < 2:
return None return None
base_embeddings: List[int] = [] base_embeddings: list[int] = []
added_embeddings: List[int] = [] added_embeddings: list[int] = []
padding: List[int] = [] padding: list[int] = []
for tp_rank in range(self.tp_size): for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded, shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded, self.org_vocab_size_padded,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment