Unverified Commit 054072be authored by Woosuk Kwon, committed by GitHub

[Minor] Move RoPE selection logic to `get_rope` (#1633)

parent eb825c1e
@@ -10,9 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from vllm import attention_ops
 from vllm import cache_ops
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.rotary_embedding import (
-    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
-    RotaryEmbedding, YaRNScalingRotaryEmbedding)
+from vllm.model_executor.layers.rotary_embedding import get_rope
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -319,36 +317,8 @@ class PagedAttentionWithRoPE(PagedAttention):
                          scale,
                          num_kv_heads,
                          sliding_window=sliding_window)
-        if rope_scaling is None:
-            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
-                                              max_position, base,
-                                              is_neox_style)
-        else:
-            scaling_type = rope_scaling["type"]
-            scaling_factor = rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LinearScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "dynamic":
-                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    scaling_factor)
-            elif scaling_type == "yarn":
-                original_max_position = rope_scaling[
-                    "original_max_position_embeddings"]
-                assert max_position == original_max_position * scaling_factor
-                extra_kwargs = {
-                    k: v
-                    for k, v in rope_scaling.items()
-                    if k in ("extrapolation_factor", "attn_factor",
-                             "beta_fast", "beta_slow")
-                }
-                self.rotary_emb = YaRNScalingRotaryEmbedding(
-                    head_size, rotary_dim, original_max_position, base,
-                    is_neox_style, scaling_factor, **extra_kwargs)
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
+                                   is_neox_style, rope_scaling)
 
     def forward(
         self,
...
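For context, the rope_scaling argument that the call site above now forwards to get_rope is the optional RoPE-scaling dictionary, typically taken from the model's Hugging Face config. A minimal sketch of the shapes get_rope handles (the key names match the diff below; the numeric values are illustrative, not from this commit):

    # rope_scaling=None selects the plain RotaryEmbedding.
    linear_scaling = {"type": "linear", "factor": 2.0}
    dynamic_scaling = {"type": "dynamic", "factor": 2.0}
    yarn_scaling = {
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 4096,  # context length before scaling
        # Optional knobs forwarded to YaRNScalingRotaryEmbedding via **extra_kwargs:
        "extrapolation_factor": 1.0,
        "attn_factor": 1.0,
        "beta_fast": 32,
        "beta_slow": 1,
    }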
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Rotary Positional Embeddings."""
 import math
-from typing import Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -271,3 +271,46 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         sin = (freqs.sin() * self.mscale)
         cache = torch.cat((cos, sin), dim=-1)
         return cache
+
+
+def get_rope(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool,
+    rope_scaling: Optional[Dict[str, Any]],
+) -> RotaryEmbedding:
+    if rope_scaling is None:
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style)
+    else:
+        scaling_type = rope_scaling["type"]
+        scaling_factor = rope_scaling["factor"]
+        if scaling_type == "linear":
+            rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
+                                                      max_position, base,
+                                                      is_neox_style,
+                                                      scaling_factor)
+        elif scaling_type == "dynamic":
+            rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                head_size, rotary_dim, max_position, base, is_neox_style,
+                scaling_factor)
+        elif scaling_type == "yarn":
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            assert max_position == original_max_position * scaling_factor
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("extrapolation_factor", "attn_factor", "beta_fast",
+                         "beta_slow")
+            }
+            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
+                                                    original_max_position,
+                                                    base, is_neox_style,
+                                                    scaling_factor,
+                                                    **extra_kwargs)
+        else:
+            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+    return rotary_emb
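A minimal usage sketch of the new helper (the parameter values here are illustrative, not from the commit; note that for "yarn", get_rope asserts that max_position equals original_max_position_embeddings * factor):

    from vllm.model_executor.layers.rotary_embedding import get_rope

    # Plain rotary embedding: no scaling config.
    rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
                    base=10000, is_neox_style=True, rope_scaling=None)

    # YaRN-scaled rotary embedding: 4096 * 4.0 == 16384, so the assert holds.
    yarn_rope = get_rope(
        head_size=128, rotary_dim=128, max_position=16384, base=10000,
        is_neox_style=True,
        rope_scaling={"type": "yarn", "factor": 4.0,
                      "original_max_position_embeddings": 4096})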