Unverified Commit 2add697d authored by Yineng Zhang, committed by GitHub

feat: remove vllm get_rope (#2964)

parent 6f98c586
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
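# Illustrative sketch (added for clarity, not part of this commit): with a zero
# rotation angle the helper should return its input unchanged. The shapes below
# (2 tokens, 1 head, head_size 8, so 4 frequencies per token) are assumed
# example values.
# >>> x = torch.randn(2, 1, 8)
# >>> angle = torch.zeros(2, 4)
# >>> out = _apply_rotary_emb(x, angle.cos(), angle.sin(), is_neox_style=True)
# >>> torch.allclose(out, x)
# True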
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
else:
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def forward_xpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
self.rotary_dim,
offsets,
)
else:
ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
return query, key
def forward_hpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
positions = positions.flatten()
if offsets is not None:
positions = positions + offsets
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
cos, sin = cos_sin.chunk(2, dim=-1)
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
# to query hidden dimension, so the original tensors need to be
# expanded
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
# and expansion of cos/sin tensors via concatenation
# GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
# and expansion of cos/sin tensors via repeat_interleave
rope_mode: RotaryPosEmbeddingMode
if self.is_neox_style:
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
cos = torch.cat((cos, cos), dim=-1)
sin = torch.cat((sin, sin), dim=-1)
else:
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
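# Usage sketch (illustrative; the sizes are assumptions, not from this commit,
# and it assumes vLLM's CustomOp base class is importable). forward_native runs
# on CPU without custom kernels:
# >>> rope = RotaryEmbedding(
# ...     head_size=64, rotary_dim=64, max_position_embeddings=2048,
# ...     base=10000, is_neox_style=True, dtype=torch.float32,
# ... )
# >>> positions = torch.arange(4)
# >>> q = torch.randn(4, 8 * 64)  # 8 query heads, flattened per token
# >>> k = torch.randn(4, 2 * 64)  # 2 KV heads, flattened per token
# >>> q_rot, k_rot = rope.forward_native(positions, q, k)
# >>> q_rot.shape, k_rot.shape
# (torch.Size([4, 512]), torch.Size([4, 128]))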
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
It supports multiple scaling factors. Since multiple LoRA adapters may have
different scaling factors, we need multiple cos/sin caches. In this way,
instead of running rotary embedding kernel per lora, we can run multiple
lora in a batched way.
In addition to that, we also keep the cos/sin cache for the scaling factor
of 1 (default) at all times.
Exemplary for two scaling factors x=1, y and z with embeddings
[[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
[[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
[[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],
we construct the cos/sin cache as follows:
[[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
...
[xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]
We then use offsets to index into the cos/sin cache for
the respective scaling factors.
The offset to cache can be accessed via `scaling_factor_to_offset` API.
Credits to the Reddit user /u/kaiokendev
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factors: Union[List[float], float],
dtype: torch.dtype,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list: List[torch.Tensor] = []
# offsets to the next cache in a tensor.
# Each offset corresponds to the same index in scaling_factors.
offsets: List[int] = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float)
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
if not cache_list:
offset = 0
else:
last_offset = offsets[-1]
next_max_len = cache_list[-1].shape[0]
offset = last_offset + next_max_len
offsets.append(offset)
cache_list.append(cache)
self._scaling_factor_to_offset = {
float(scaling_factor): offsets[i]
for i, scaling_factor in enumerate(self.scaling_factors)
}
assert len(self.scaling_factors) == len(offsets)
return torch.cat(cache_list, dim=0)
@property
def scaling_factor_to_offset(self) -> Dict[float, int]:
return self._scaling_factor_to_offset
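# Sketch of the offset bookkeeping (illustrative values, not from this commit):
# with two scaling factors the per-factor caches are concatenated, and each
# factor's offset is the total length of the caches that precede it.
# >>> linear_rope = LinearScalingRotaryEmbedding(
# ...     head_size=64, rotary_dim=64, max_position_embeddings=2048,
# ...     base=10000, is_neox_style=True, scaling_factors=[1.0, 4.0],
# ...     dtype=torch.float32,
# ... )
# >>> linear_rope.scaling_factor_to_offset
# {1.0: 0, 4.0: 2048}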
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
base = self.base * (
(self.scaling_factor * max_len / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.rotary_dim / (self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
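# Worked example (illustrative, values assumed): for scaling_factor=2.0 the
# cache covers max_position_embeddings * 2 positions, and the effective base is
# rescaled by (factor**2 - (factor - 1)) ** (rotary_dim / (rotary_dim - 2)),
# i.e. roughly 3x for rotary_dim=128.
# >>> base, factor, rotary_dim = 10000, 2.0, 128
# >>> round(base * (factor * factor - (factor - 1)) ** (rotary_dim / (rotary_dim - 2)))
# 30528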
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(
num_rotations: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048,
) -> float:
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
# Find dim range bounds based on rotations
def _yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048,
) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def _yarn_linear_ramp_mask(
low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
if low == high:
high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def _yarn_get_mscale(scale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
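# Quick numeric check of the helpers above (illustrative; uses the default
# base=10000 and max_position_embeddings=2048 from the signatures):
# >>> _yarn_find_correction_range(low_rot=32, high_rot=1, dim=128)
# (16, 41)
# >>> _yarn_linear_ramp_mask(16.0, 41.0, 64, torch.float32).shape
# torch.Size([64])
# >>> round(_yarn_get_mscale(4.0), 4)
# 1.1386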
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.rotary_dim,
self.base,
self.max_position_embeddings,
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (
1
- _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
) * self.extrapolation_factor
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * self.mscale
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
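# Construction sketch (illustrative values, not from this commit): a 4x YaRN
# extension of a 4k original context; the cos/sin cache then covers 16384
# positions and is pre-scaled by mscale.
# >>> yarn_rope = YaRNScalingRotaryEmbedding(
# ...     head_size=128, rotary_dim=128, max_position_embeddings=4096,
# ...     base=10000, is_neox_style=True, scaling_factor=4.0,
# ...     dtype=torch.float32,
# ... )
# >>> yarn_rope.cos_sin_cache.shape
# torch.Size([16384, 128])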
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
short_factor: List[float],
long_factor: List[float],
short_mscale: Optional[float] = None,
long_mscale: Optional[float] = None,
):
super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
rotary_dim != head_size ({rotary_dim}!={head_size})."
)
if is_neox_style is False:
raise ValueError(
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
)
if short_mscale is None:
short_mscale = scaling_factor
if long_mscale is None:
long_mscale = scaling_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale
)
short_cache = short_cache.to(dtype)
self.register_buffer("short_cos_sin_cache", short_cache, persistent=False)
long_cache = self._compute_cos_sin_cache(
max_position_embeddings, long_factor, long_mscale
)
long_cache = long_cache.to(dtype)
self.register_buffer("long_cos_sin_cache", long_cache, persistent=False)
long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0
)
self.register_buffer(
"long_short_cos_sin_cache", long_short_cache, persistent=False
)
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (
rescale_factors
* (
self.base
** (
torch.arange(0, self.head_size, 2, dtype=torch.float)
/ self.head_size
)
)
)
return inv_freq
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
k = self.original_max_position_embeddings
long_prompt_offset = (
torch.any(positions > k).float() * torch.full_like(positions, k)
).long()
idx = (
torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None
else positions
)
self.long_short_cos_sin_cache: torch.Tensor = self.long_short_cos_sin_cache.to(
idx.device
)
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
return query.flatten(-2), key.flatten(-2)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
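# Numeric sketch (illustrative; scale=40 and mscale=1.0 are assumed,
# DeepSeek-style values, not taken from this commit):
# >>> round(yarn_get_mscale(40, 1.0), 3)
# 1.369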
# TODO: in the DeepseekScalingRotaryEmbedding class defined in vllm,
# the device has been hard-coded to "cuda" in these two places:
# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L646
# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L665
# We port the related code to this file to make it compatible with the CPU version.
# We will add an optimized rotary embedding kernel for CPU and will remove the ported code then.
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
...@@ -149,7 +664,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
...@@ -162,14 +676,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
/ self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
...@@ -197,7 +710,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
...@@ -248,10 +761,249 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
return query, key
class Llama3RotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
low_freq_wavelen = self.orig_max_position / self.low_freq_factor
high_freq_wavelen = self.orig_max_position / self.high_freq_factor
wave_len = 2 * math.pi / inv_freqs
if self.low_freq_factor != self.high_freq_factor:
smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / (
self.high_freq_factor - self.low_freq_factor
)
else:
smooth = 0
new_freqs = torch.where(
wave_len < high_freq_wavelen,
inv_freqs,
torch.where(
wave_len > low_freq_wavelen,
inv_freqs / self.scaling_factor,
(1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs,
),
)
return new_freqs
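# Sketch with Llama-3.1-style scaling values (taken as assumptions here, not
# from this commit): frequencies with wavelengths below the high-frequency
# cutoff are kept, those above the low-frequency cutoff are divided by the
# factor, and the band in between is smoothly interpolated.
# >>> llama3_rope = Llama3RotaryEmbedding(
# ...     head_size=128, rotary_dim=128, max_position_embeddings=131072,
# ...     base=500000, is_neox_style=True, dtype=torch.float32,
# ...     scaling_factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
# ...     orig_max_position=8192,
# ... )
# >>> llama3_rope.cos_sin_cache.shape
# torch.Size([131072, 128])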
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
self.mrope_section = mrope_section
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if positions.ndim == 2:
assert self.mrope_section
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
dim=-1,
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
dim=-1,
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@staticmethod
def get_input_positions(
input_tokens: List[int],
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
image_token_id: int,
video_token_id: int,
vision_start_token_id: int,
vision_end_token_id: int,
spatial_merge_size: int,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
"""Get mrope input positions and delta value."""
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
video_grid_thw = video_grid_thw.tolist()
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
llm_positions = llm_positions[:, context_len:seq_len]
return llm_positions.tolist(), mrope_position_delta
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
context_len: int,
seq_len: int,
) -> List[List[int]]:
return [
list(
range(
context_len + mrope_position_delta, seq_len + mrope_position_delta
)
)
for _ in range(3)
]
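# Walk-through sketch with toy token ids (all values assumed for illustration):
# two text tokens, a vision-start marker, four image tokens for a 1x2x2 grid
# with spatial_merge_size=1, then one trailing text token.
# >>> toy_tokens = [101, 102, 7, 8, 8, 8, 8, 103]
# >>> pos, delta = MRotaryEmbedding.get_input_positions(
# ...     input_tokens=toy_tokens,
# ...     image_grid_thw=[[1, 2, 2]],
# ...     video_grid_thw=[],
# ...     image_token_id=8,
# ...     video_token_id=9,
# ...     vision_start_token_id=7,
# ...     vision_end_token_id=10,
# ...     spatial_merge_size=1,
# ... )
# >>> len(pos), len(pos[0])  # 3 (T/H/W) rows, one column per token
# (3, 8)
# >>> delta  # max position + 1 - len(input_tokens)
# -2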
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
...@@ -260,7 +1012,6 @@ def get_rope_cpu(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
...@@ -286,75 +1037,140 @@ def get_rope_cpu(
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
else:
rotary_emb = RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
}
rotary_emb = YaRNScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k
in (
"extrapolation_factor",
"attn_factor",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
)
}
rotary_emb = DeepseekScalingRotaryEmbedding(
head_size,
rotary_dim,
original_max_position,
base,
is_neox_style,
scaling_factor,
dtype,
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size,
rotary_dim,
max_position,
original_max_position,
base,
is_neox_style,
dtype,
short_factor,
long_factor,
**extra_kwargs,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
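# Usage sketch (illustrative; the rope_scaling dict mirrors a Llama-3.1-style
# config and is an assumption, not part of this commit). With rope_scaling=None
# the plain RotaryEmbedding is returned, and every result is cached in
# _ROPE_DICT keyed by its configuration.
# >>> rope = get_rope(
# ...     head_size=128, rotary_dim=128, max_position=131072, base=500000,
# ...     is_neox_style=True,
# ...     rope_scaling={
# ...         "rope_type": "llama3", "factor": 8.0, "low_freq_factor": 1.0,
# ...         "high_freq_factor": 4.0, "original_max_position_embeddings": 8192,
# ...     },
# ...     dtype=torch.float32,
# ... )
# >>> type(rope).__name__
# 'Llama3RotaryEmbedding'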
...@@ -24,7 +24,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import ChatGLMConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
...@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -44,7 +44,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -59,6 +58,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......
...@@ -19,7 +19,6 @@ from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import DbrxConfig
from sglang.srt.distributed import (
...@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
......
...@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -23,7 +23,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -49,7 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
...@@ -272,7 +271,7 @@ class DeepseekV2Attention(nn.Module):
quant_config=quant_config,
)
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
......
...@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
...@@ -34,6 +33,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......
...@@ -15,12 +15,11 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
from typing import Iterable, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
...@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
...@@ -45,10 +45,6 @@ def get_attention_sliding_window_size(config):
return config.sliding_window - 1
# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
class Gemma2MLP(nn.Module):
def __init__(
self,
......
...@@ -25,8 +25,6 @@ from transformers import GPT2Config
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
# from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
......
...@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import GraniteConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -22,7 +22,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -39,6 +38,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -18,7 +18,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...@@ -31,6 +30,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
...@@ -38,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -23,7 +23,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
...@@ -39,6 +38,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......
...@@ -20,7 +20,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import OlmoConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......