Unverified Commit 1aa2f81b authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Update type annotation for rotary embedding `base` (#18914)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d54af615
...@@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora( ...@@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora(
seed: int, seed: int,
device: str, device: str,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: float = 10000,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
......
...@@ -70,7 +70,7 @@ def test_rotary_embedding( ...@@ -70,7 +70,7 @@ def test_rotary_embedding(
device: str, device: str,
use_key: bool, use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: float = 10000,
) -> None: ) -> None:
if rotary_dim is None: if rotary_dim is None:
rotary_dim = head_size rotary_dim = head_size
...@@ -135,7 +135,7 @@ def test_batched_rotary_embedding( ...@@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
device: str, device: str,
use_key: bool, use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: float = 10000,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
...@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
device: str, device: str,
use_key: bool, use_key: bool,
max_position: int = 8192, max_position: int = 8192,
base: int = 10000, base: float = 10000,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
......
...@@ -96,7 +96,7 @@ class RotaryEmbedding(CustomOp): ...@@ -96,7 +96,7 @@ class RotaryEmbedding(CustomOp):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
) -> None: ) -> None:
...@@ -113,7 +113,7 @@ class RotaryEmbedding(CustomOp): ...@@ -113,7 +113,7 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache: torch.Tensor self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to # NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we # use CPU to compute the cache and then move it to GPU. However, we
...@@ -404,7 +404,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -404,7 +404,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
scaling_factors: Union[list[float], float], scaling_factors: Union[list[float], float],
dtype: torch.dtype, dtype: torch.dtype,
...@@ -464,7 +464,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding): ...@@ -464,7 +464,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -474,7 +474,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding): ...@@ -474,7 +474,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype) is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
base = self.base * (self.scaling_factor if self.mixed_b is None else 1) base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
inv_freq = super()._compute_inv_freq(base) inv_freq = super()._compute_inv_freq(base)
...@@ -501,7 +501,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): ...@@ -501,7 +501,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -582,7 +582,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -582,7 +582,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -644,7 +644,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): ...@@ -644,7 +644,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
original_max_position_embeddings: int, original_max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
short_factor: list[float], short_factor: list[float],
...@@ -769,7 +769,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -769,7 +769,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -877,7 +877,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -877,7 +877,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
scaling_factor: float, scaling_factor: float,
...@@ -892,7 +892,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -892,7 +892,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype) is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base) inv_freqs = super()._compute_inv_freq(base)
low_freq_wavelen = self.orig_max_position / self.low_freq_factor low_freq_wavelen = self.orig_max_position / self.low_freq_factor
high_freq_wavelen = self.orig_max_position / self.high_freq_factor high_freq_wavelen = self.orig_max_position / self.high_freq_factor
...@@ -923,14 +923,14 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): ...@@ -923,14 +923,14 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
): ):
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype) is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base) inv_freqs = super()._compute_inv_freq(base)
inv_freqs = inv_freqs[:(self.rotary_dim // 2)] inv_freqs = inv_freqs[:(self.rotary_dim // 2)]
return inv_freqs return inv_freqs
...@@ -989,7 +989,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -989,7 +989,7 @@ class MRotaryEmbedding(RotaryEmbedding):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
mrope_section: Optional[list[int]] = None, mrope_section: Optional[list[int]] = None,
...@@ -1529,7 +1529,7 @@ class DualChunkRotaryEmbedding(CustomOp): ...@@ -1529,7 +1529,7 @@ class DualChunkRotaryEmbedding(CustomOp):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
chunk_size: int, chunk_size: int,
...@@ -1558,7 +1558,7 @@ class DualChunkRotaryEmbedding(CustomOp): ...@@ -1558,7 +1558,7 @@ class DualChunkRotaryEmbedding(CustomOp):
q_inter_cache, q_inter_cache,
persistent=False) persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to # However, we use `torch.arange(..., dtype=torch.float)` instead to
...@@ -1705,7 +1705,7 @@ def get_rope( ...@@ -1705,7 +1705,7 @@ def get_rope(
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position: int, max_position: int,
base: int, base: float,
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: Optional[dict[str, Any]] = None, rope_scaling: Optional[dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
......
...@@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp): ...@@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
head_size: int, head_size: int,
rotary_dim: int, rotary_dim: int,
max_position: int, max_position: int,
base: int, base: float,
is_neox_style: bool, is_neox_style: bool,
cache_dtype: torch.dtype, cache_dtype: torch.dtype,
) -> None: ) -> None:
...@@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp): ...@@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
cache = self._compute_cos_sin_cache().to(cache_dtype) cache = self._compute_cos_sin_cache().to(cache_dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False) self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq( def _compute_inv_freq(self, base: float) -> torch.Tensor:
self,
base: Union[int, float],
) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
inv_freq = 1.0 / (base**(torch.arange( inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment