change / sglang / Commits / 2add697d

Commit 2add697d (unverified), authored Jan 18, 2025 by Yineng Zhang, committed by GitHub on Jan 18, 2025
feat: remove vllm get_rope (#2964)
parent 6f98c586

Changes: 30. Showing 20 changed files with 1016 additions and 207 deletions (+1016 / -207, page 1 of 2).
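All of the per-model diffs below reduce to the same one-line import swap; the pattern, using the exact module paths from this commit, is:

# removed in each model file:
# from vllm.model_executor.layers.rotary_embedding import get_rope

# added in each model file:
from sglang.srt.layers.rotary_embedding import get_rope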
python/sglang/srt/layers/rotary_embedding.py   +996  -180
python/sglang/srt/models/baichuan.py           +1    -1
python/sglang/srt/models/chatglm.py            +1    -1
python/sglang/srt/models/commandr.py           +1    -1
python/sglang/srt/models/dbrx.py               +1    -1
python/sglang/srt/models/deepseek.py           +1    -1
python/sglang/srt/models/deepseek_v2.py        +2    -3
python/sglang/srt/models/exaone.py             +1    -1
python/sglang/srt/models/gemma.py              +1    -1
python/sglang/srt/models/gemma2.py             +2    -6
python/sglang/srt/models/gpt2.py               +0    -2
python/sglang/srt/models/granite.py            +1    -1
python/sglang/srt/models/grok.py               +1    -1
python/sglang/srt/models/internlm2.py          +1    -1
python/sglang/srt/models/llama.py              +1    -1
python/sglang/srt/models/minicpm.py            +1    -1
python/sglang/srt/models/minicpm3.py           +1    -1
python/sglang/srt/models/mixtral.py            +1    -1
python/sglang/srt/models/mixtral_quant.py      +1    -1
python/sglang/srt/models/olmo.py               +1    -1
python/sglang/srt/layers/rotary_embedding.py  (view file @ 2add697d)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""MRotaryEmbedding"""
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py
+"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
-from vllm.model_executor.layers.rotary_embedding import (
-    RotaryEmbedding,
-    _rotate_gptj,
-    _rotate_neox,
-    _yarn_find_correction_range,
-    _yarn_linear_ramp_mask,
-    get_rope,
-    yarn_get_mscale,
-)
-
-
-class MRotaryEmbedding:
-    """Rotary Embedding with Multimodal Sections."""
+import torch.nn as nn
+from vllm.model_executor.custom_op import CustomOp

-    @staticmethod
-    def get_input_positions(
-        input_tokens: torch.Tensor,
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        vision_start_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-    ) -> Tuple[List[List[int]], int]:
-        """Get mrope input positions and delta value."""
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


-        vision_start_indices = torch.argwhere(
-            input_tokens == vision_start_token_id
-        ).squeeze(1)
-        image_indices = vision_start_indices + 1
-        image_nums = image_indices.shape[0]
-        llm_pos_ids_list: list = []
-        st = 0
-        input_tokens_len = input_tokens.shape[0]
-        for image_index in range(image_nums):
-            ed = image_indices[image_index].item()
-            t, h, w = (
-                image_grid_thw[image_index][0],
-                image_grid_thw[image_index][1],
-                image_grid_thw[image_index][2],
-            )
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
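A minimal sanity check of the helper above, assuming it runs inside this module (tensor sizes are arbitrary and chosen only for illustration):

import torch

x = torch.randn(2, 1, 8)       # [num_tokens=2, num_heads=1, head_size=8]
angles = torch.rand(2, 4)      # one angle per rotated pair (head_size // 2)
cos, sin = angles.cos(), angles.sin()

out_neox = _apply_rotary_emb(x, cos, sin, is_neox_style=True)
out_gptj = _apply_rotary_emb(x, cos, sin, is_neox_style=False)
assert out_neox.shape == out_gptj.shape == x.shape
# A rotation preserves the norm of each (x1, x2) pair, so the total norm is unchanged.
assert torch.allclose(out_neox.norm(), x.norm(), atol=1e-5)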
@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        cache = cache.to(dtype)
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (
            base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        )
        return inv_freq

-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)
        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from vllm import _custom_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
                self.rotary_dim,
                offsets,
            )
        else:
            ops.rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
        return query, key

-            t_index = (
-                torch.arange(llm_grid_t)
-                .view(-1, 1)
-                .expand(-1, llm_grid_h * llm_grid_w)
-                .flatten()
-            )
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from vllm._ipex_ops import ipex_ops as ops

        self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
        # ops.rotary_embedding()/batched_rotary_embedding()
        # are in-place operations that update the query and key tensors.
        if offsets is not None:
            ops.batched_rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
                self.rotary_dim,
                offsets,
            )
        else:
            ops.rotary_embedding(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
                self.is_neox_style,
            )
        return query, key

    def forward_hpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from habana_frameworks.torch.hpex.kernels import (
            RotaryPosEmbeddingMode,
            apply_rotary_pos_emb,
        )

        positions = positions.flatten()
        if offsets is not None:
            positions = positions + offsets
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)
        cos, sin = cos_sin.chunk(2, dim=-1)
        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
        # to query hidden dimension, so the original tensors need to be
        # expanded
        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
        # and expansion of cos/sin tensors via concatenation
        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
        # and expansion of cos/sin tensors via repeat_interleave
        rope_mode: RotaryPosEmbeddingMode
        if self.is_neox_style:
            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
            cos = torch.cat((cos, cos), dim=-1)
            sin = torch.cat((sin, sin), dim=-1)
        else:
            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
            sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1])
            cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1])

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def extra_repr(self) -> str:
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        return s
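A minimal usage sketch of the class above, taking the PyTorch-native path so it does not need the vLLM, IPEX, or HPU kernels (sizes are illustrative; vLLM itself must still be importable because of the CustomOp base class):

rope = RotaryEmbedding(
    head_size=64,
    rotary_dim=64,
    max_position_embeddings=2048,
    base=10000,
    is_neox_style=True,
    dtype=torch.float32,
)
positions = torch.arange(16)          # one 16-token sequence
query = torch.randn(16, 8 * 64)       # [num_tokens, num_heads * head_size]
key = torch.randn(16, 8 * 64)         # [num_tokens, num_kv_heads * head_size]
query, key = rope.forward_native(positions, query, key)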
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    It supports multiple scaling factors. Since multiple LoRA adapters may have
    different scaling factors, we need multiple cos/sin caches. In this way,
    instead of running rotary embedding kernel per lora, we can run multiple
    lora in a batched way.

    In addition to that, we also keep the cos/sin cache for the scaling factor
    of 1 (default) at all times.

    Exemplary for two scaling factors x=1, y and z with embeddings
    [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and
    [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and
    [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]],

    we construct the cos/sin cache as follows:
    [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p],
        ...
     [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]]

    We then use offsets to index into the cos/sin cache for
    the respective scaling factors.

    The offset to cache can be accessed via `scaling_factor_to_offset` API.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factors: Union[List[float], float],
        dtype: torch.dtype,
    ) -> None:
        if isinstance(scaling_factors, float):
            scaling_factors = [scaling_factors]
        self.scaling_factors: List[float] = scaling_factors  # noqa
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )
        # Lazy initialized.
        self._scaling_factor_to_offset: Dict[float, int]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        cache_list: List[torch.Tensor] = []
        # offsets to the next cache in a tensor.
        # Each offset corresponds to the same index in scaling_factors.
        offsets: List[int] = []
        for scaling_factor in self.scaling_factors:
            # NOTE(woosuk): self.max_position_embeddings is the original
            # maximum length before applying the rope scaling.
            # Thus, the maximum length after applying the rope scaling is
            # self.max_position_embeddings * self.scaling_factor.
            max_len = self.max_position_embeddings * scaling_factor
            t = torch.arange(max_len, dtype=torch.float)
            t = t / scaling_factor

            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cos = freqs.cos()
            sin = freqs.sin()
            cache = torch.cat((cos, sin), dim=-1)
            if not cache_list:
                offset = 0
            else:
                last_offset = offsets[-1]
                next_max_len = cache_list[-1].shape[0]
                offset = last_offset + next_max_len
            offsets.append(offset)
            cache_list.append(cache)
        self._scaling_factor_to_offset = {
            float(scaling_factor): offsets[i]
            for i, scaling_factor in enumerate(self.scaling_factors)
        }
        assert len(self.scaling_factors) == len(offsets)
        return torch.cat(cache_list, dim=0)

    @property
    def scaling_factor_to_offset(self) -> Dict[float, int]:
        return self._scaling_factor_to_offset


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings)
            - (self.scaling_factor - 1)
        ) ** (self.rotary_dim / (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
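Concretely, the class above enlarges the RoPE base before rebuilding the cache; a small worked example of that formula (numbers are arbitrary and only for illustration):

# max_position_embeddings=2048, scaling_factor=2, base=10000, rotary_dim=d
d = 128
max_len = 2048 * 2
adjusted_base = 10000 * ((2 * max_len / 2048) - (2 - 1)) ** (d / (d - 2))
print(adjusted_base)   # roughly 3.05e4, i.e. the base grows with the scaling factor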
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float:
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


# Find dim range bounds based on rotations
def _yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> Tuple[int, int]:
    low = math.floor(
        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def _yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
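As a concrete value for the helper just above: a 4x context extension only scales attention magnitudes slightly (worked example, not taken from the source):

import math

scale = 4.0
mscale = 0.1 * math.log(scale) + 1.0   # same formula as _yarn_get_mscale
print(round(mscale, 4))                # 0.1 * ln(4) + 1, about 1.1386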
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation
        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base ** (
            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
        )
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = _yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (
            1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
        ) * self.extrapolation_factor
        inv_freq = (
            inv_freq_interpolation * (1 - inv_freq_mask)
            + inv_freq_extrapolation * inv_freq_mask
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(
            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
        )
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: List[float],
        long_factor: List[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        super().__init__()

        if rotary_dim != head_size:
            raise ValueError(
                f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
                    rotary_dim != head_size ({rotary_dim} != {head_size})."
            )
        if is_neox_style is False:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < input_tokens_len:
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = input_tokens_len - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )

        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) / math.log(self.original_max_position_embeddings)
            )
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        llm_positions = llm_positions[:, context_len:]
-        mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item()
-
-        return llm_positions.tolist(), mrope_position_delta

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> List[List[int]]:
-        return [
-            list(
-                range(
-                    context_len + mrope_position_delta,
-                    seq_len + mrope_position_delta,
-                )
-            )
-            for _ in range(3)
-        ]

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale
        )
        short_cache = short_cache.to(dtype)
        self.register_buffer("short_cos_sin_cache", short_cache, persistent=False)

        long_cache = self._compute_cos_sin_cache(
            max_position_embeddings, long_factor, long_mscale
        )
        long_cache = long_cache.to(dtype)
        self.register_buffer("long_cos_sin_cache", long_cache, persistent=False)

        long_short_cache = torch.cat(
            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0
        )
        self.register_buffer(
            "long_short_cos_sin_cache", long_short_cache, persistent=False
        )

    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (
            rescale_factors
            * (
                self.base
                ** (torch.arange(0, self.head_size, 2, dtype=torch.float) / self.head_size)
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: List[float],
        mscale: float,
    ) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        k = self.original_max_position_embeddings
        long_prompt_offset = (
            torch.any(positions > k).float() * torch.full_like(positions, k)
        ).long()
        idx = (
            torch.add(positions, long_prompt_offset)
            if long_prompt_offset is not None
            else positions
        )
        self.long_short_cos_sin_cache: torch.Tensor = self.long_short_cos_sin_cache.to(
            idx.device
        )
        idx = torch.add(idx, offsets) if offsets is not None else idx
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query = query * cos + _rotate_neox(query) * sin
        key = key * cos + _rotate_neox(key) * sin

        return query.flatten(-2), key.flatten(-2)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
# TODO: in the DeepseekScalingRotaryEmbedding class defined in vllm,
# the device has been hard-coded to "cuda" in these two places:
# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L646
# https://github.com/vllm-project/vllm/blob/8a1f938e6f02052df0f4953c149410605a2d56d8/vllm/model_executor/layers/rotary_embedding.py#L665
# We port the related code to this file to make it compatible with the CPU version.
# We will add an optimized rotary embedding kernel for CPU and will remove the ported code then.
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.
...
...
@@ -149,7 +664,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
-        device: Optional[str] = None,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
...
...
@@ -162,14 +676,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
            / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
            * attn_factor
        )
-        self.device = device
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base ** (
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
            / self.rotary_dim
        )
        inv_freq_extrapolation = 1.0 / pos_freqs
...
...
@@ -197,7 +710,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
-            device=self.device,
+            device="cuda",
            dtype=torch.float32,
        )
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
...
...
@@ -248,10 +761,249 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
        return query, key
class Llama3RotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        scaling_factor: float,
        low_freq_factor: float,
        high_freq_factor: float,
        orig_max_position: int,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.orig_max_position = orig_max_position
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freqs = super()._compute_inv_freq(base)
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor

        wave_len = 2 * math.pi / inv_freqs
        if self.low_freq_factor != self.high_freq_factor:
            smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / (
                self.high_freq_factor - self.low_freq_factor
            )
        else:
            smooth = 0
        new_freqs = torch.where(
            wave_len < high_freq_wavelen,
            inv_freqs,
            torch.where(
                wave_len > low_freq_wavelen,
                inv_freqs / self.scaling_factor,
                (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs,
            ),
        )
        return new_freqs
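The nested torch.where above partitions frequencies by wavelength; a short worked example of the two thresholds (the numbers follow the published Llama 3.1 rope_scaling values as commonly reported, used here only for illustration):

orig_max_position, low_freq_factor, high_freq_factor = 8192, 1.0, 4.0
low_freq_wavelen = orig_max_position / low_freq_factor     # 8192.0
high_freq_wavelen = orig_max_position / high_freq_factor   # 2048.0
# Wavelengths shorter than 2048 tokens keep their inverse frequency, wavelengths
# longer than 8192 tokens are divided by scaling_factor, and the band in between
# is blended linearly via `smooth`.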
class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        mrope_section: Optional[List[int]] = None,
    ) -> None:
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

        self.mrope_section = mrope_section
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        assert positions.ndim == 1 or positions.ndim == 2

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if positions.ndim == 2:
            assert self.mrope_section

            cos = torch.cat(
                [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))],
                dim=-1,
            )
            sin = torch.cat(
                [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))],
                dim=-1,
            )

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @staticmethod
    def get_input_positions(
        input_tokens: List[int],
        image_grid_thw: Union[List[List[int]], torch.Tensor],
        video_grid_thw: Union[List[List[int]], torch.Tensor],
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
        vision_end_token_id: int,
        spatial_merge_size: int,
        context_len: int = 0,
        seq_len: Optional[int] = None,
    ) -> Tuple[List[List[int]], int]:
        """Get mrope input positions and delta value."""
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()
        if isinstance(video_grid_thw, torch.Tensor):
            video_grid_thw = video_grid_thw.tolist()

        input_tokens_tensor = torch.tensor(input_tokens)
        vision_start_indices = torch.argwhere(
            input_tokens_tensor == vision_start_token_id
        ).squeeze(1)
        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_videos = image_nums, video_nums

        image_index, video_index = 0, 0
        for _ in range(image_nums + video_nums):
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_videos -= 1
                ed = ed_video
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                torch.arange(llm_grid_t)
                .view(-1, 1)
                .expand(-1, llm_grid_h * llm_grid_w)
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions.tolist(), mrope_position_delta
    @staticmethod
    def get_next_input_positions(
        mrope_position_delta: int,
        context_len: int,
        seq_len: int,
    ) -> List[List[int]]:
        return [
            list(
                range(
                    context_len + mrope_position_delta,
                    seq_len + mrope_position_delta,
                )
            )
            for _ in range(3)
        ]
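A small self-contained illustration of MRotaryEmbedding.get_input_positions as defined above; the token ids and the image grid are made up for the example and are not real vocabulary ids:

VSTART, VEND, IMG, VID = 1001, 1002, 1003, 1004   # hypothetical ids
tokens = [7, 8, 9, VSTART] + [IMG] * 4 + [VEND, 10, 11]
positions, delta = MRotaryEmbedding.get_input_positions(
    input_tokens=tokens,
    image_grid_thw=[[1, 4, 4]],   # one image: t=1, h=4, w=4 patches
    video_grid_thw=[],
    image_token_id=IMG,
    video_token_id=VID,
    vision_start_token_id=VSTART,
    vision_end_token_id=VEND,
    spatial_merge_size=2,         # a 4x4 patch grid merges into 2x2, i.e. 4 image tokens
)
# positions is a [3, len(tokens)] nested list of T/H/W position ids;
# delta is the offset applied to positions generated during decoding.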
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
-def get_rope_cpu(
+def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
...
...
@@ -260,7 +1012,6 @@ def get_rope_cpu(
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
-    device: Optional[str] = None,
) -> RotaryEmbedding:
    if dtype is None:
        dtype = torch.get_default_dtype()
...
...
@@ -286,75 +1037,140 @@ def get_rope_cpu(
    if key in _ROPE_DICT:
        return _ROPE_DICT[key]

-    assert rope_scaling is not None
-    scaling_type = rope_scaling["rope_type"]
-    assert (
-        scaling_type == "deepseek_yarn"
-    ), "Only deepseek_yarn is supported for CPU for now"
-    scaling_factor = rope_scaling["factor"]
-    original_max_position = rope_scaling["original_max_position_embeddings"]
-    # assert max_position == original_max_position * scaling_factor
-    extra_kwargs = {
-        k: v
-        for k, v in rope_scaling.items()
-        if k
-        in (
-            "extrapolation_factor",
-            "attn_factor",
-            "beta_fast",
-            "beta_slow",
-            "mscale",
-            "mscale_all_dim",
-        )
-    }
-    extra_kwargs["device"] = device
-    rotary_emb = DeepseekScalingRotaryEmbedding(
-        head_size,
-        rotary_dim,
-        original_max_position,
-        base,
-        is_neox_style,
-        scaling_factor,
-        dtype,
-        **extra_kwargs,
-    )

    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        )
    else:
        scaling_type = rope_scaling["rope_type"]

        if scaling_type == "llama3":
            scaling_factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            rotary_emb = Llama3RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                dtype,
                scaling_factor,
                low_freq_factor,
                high_freq_factor,
                original_max_position,
            )
        elif scaling_type == "default":
            if "mrope_section" in rope_scaling:
                rotary_emb = MRotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                    mrope_section=rope_scaling["mrope_section"],
                )
            else:
                rotary_emb = RotaryEmbedding(
                    head_size,
                    rotary_dim,
                    max_position,
                    base,
                    is_neox_style,
                    dtype,
                )
        elif scaling_type == "linear":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = LinearScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
        elif scaling_type == "dynamic":
            scaling_factor = rope_scaling["factor"]
            rotary_emb = DynamicNTKScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
            )
        elif scaling_type == "yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
            }
            rotary_emb = YaRNScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
                **extra_kwargs,
            )
        elif scaling_type == "deepseek_yarn":
            scaling_factor = rope_scaling["factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            # assert max_position == original_max_position * scaling_factor
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k
                in (
                    "extrapolation_factor",
                    "attn_factor",
                    "beta_fast",
                    "beta_slow",
                    "mscale",
                    "mscale_all_dim",
                )
            }
            rotary_emb = DeepseekScalingRotaryEmbedding(
                head_size,
                rotary_dim,
                original_max_position,
                base,
                is_neox_style,
                scaling_factor,
                dtype,
                **extra_kwargs,
            )
        elif scaling_type == "longrope":
            short_factor = rope_scaling["short_factor"]
            long_factor = rope_scaling["long_factor"]
            original_max_position = rope_scaling["original_max_position_embeddings"]
            extra_kwargs = {
                k: v
                for k, v in rope_scaling.items()
                if k in ("short_mscale", "long_mscale")
            }
            rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
                head_size,
                rotary_dim,
                max_position,
                original_max_position,
                base,
                is_neox_style,
                dtype,
                short_factor,
                long_factor,
                **extra_kwargs,
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    _ROPE_DICT[key] = rotary_emb
    return rotary_emb
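A hedged usage sketch of the factory above. Only head_size, rotary_dim, max_position, rope_scaling, dtype and partial_rotary_factor are visible in this diff; the remaining positional parameters are assumed to follow the order used by get_rope_wrapper further down (head_size, rotary_dim, max_position, base, is_neox_style, ...), and the YaRN numbers are illustrative:

rope = get_rope(
    128,       # head_size
    128,       # rotary_dim
    32768,     # max_position
    10000,     # base (assumed position)
    True,      # is_neox_style (assumed position)
    rope_scaling={
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 8192,
    },
    dtype=torch.float32,
)
assert isinstance(rope, YaRNScalingRotaryEmbedding)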
def get_rope_wrapper(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
    dtype: Optional[torch.dtype] = None,
    partial_rotary_factor: float = 1.0,
    device: Optional[str] = None,
):
    if device != "cpu":
        return get_rope(
            head_size,
            rotary_dim,
            max_position,
            base,
            is_neox_style,
            rope_scaling,
            dtype,
            partial_rotary_factor,
        )

    return get_rope_cpu(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        rope_scaling,
        dtype,
        partial_rotary_factor,
        device,
    )
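A short sketch of the dispatch above, using the wrapper's own parameter order (values are illustrative):

# Default (non-CPU) path: resolved through get_rope() and cached in _ROPE_DICT.
rope = get_rope_wrapper(128, 128, 4096, 10000, True, None, torch.float32)
# Passing device="cpu" instead routes the call to the get_rope_cpu branch.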
python/sglang/srt/models/baichuan.py  (view file @ 2add697d)
...
...
@@ -24,7 +24,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/chatglm.py  (view file @ 2add697d)
...
...
@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from torch.nn import LayerNorm
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import ChatGLMConfig
from sglang.srt.distributed import get_tensor_model_parallel_world_size
...
...
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/commandr.py  (view file @ 2add697d)
...
...
@@ -44,7 +44,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -59,6 +58,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
...
...
python/sglang/srt/models/dbrx.py  (view file @ 2add697d)
...
...
@@ -19,7 +19,6 @@ from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.configs import DbrxConfig
from sglang.srt.distributed import (
...
...
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
...
...
python/sglang/srt/models/deepseek.py  (view file @ 2add697d)
...
...
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/deepseek_v2.py  (view file @ 2add697d)
...
...
@@ -23,7 +23,6 @@ import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -49,7 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
@@ -272,7 +271,7 @@ class DeepseekV2Attention(nn.Module):
            quant_config=quant_config,
        )
        rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope_wrapper(
+        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
...
...
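deepseek_v2.py is the only model file here that changes more than an import: the attention layer now calls get_rope directly (instead of get_rope_wrapper) with the rope type forced to "deepseek_yarn". A hedged sketch of what such a call resolves to; everything beyond the three arguments visible in the diff, and all numeric values, are assumptions for illustration only:

rope_scaling = {
    "rope_type": "deepseek_yarn",
    "factor": 40,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
}
rotary_emb = get_rope(
    64,                        # qk_rope_head_dim, illustrative
    rotary_dim=64,
    max_position=163840,
    base=10000,                # assumed parameter name
    rope_scaling=rope_scaling,
    is_neox_style=False,       # assumed, following the ported vLLM usage
)
# Resolves to the "deepseek_yarn" branch of get_rope(), i.e. a
# DeepseekScalingRotaryEmbedding built without any device argument.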
python/sglang/srt/models/exaone.py  (view file @ 2add697d)
...
...
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...
...
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/gemma.py  (view file @ 2add697d)
...
...
@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
...
...
@@ -34,6 +33,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
...
...
python/sglang/srt/models/gemma2.py  (view file @ 2add697d)
...
...
@@ -15,12 +15,11 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
-from typing import Iterable, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
...
...
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
...
...
@@ -45,10 +45,6 @@ def get_attention_sliding_window_size(config):
    return config.sliding_window - 1
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
class Gemma2MLP(nn.Module):
    def __init__(
        self,
...
...
python/sglang/srt/models/gpt2.py  (view file @ 2add697d)
...
...
@@ -25,8 +25,6 @@ from transformers import GPT2Config
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
# from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
...
...
python/sglang/srt/models/granite.py  (view file @ 2add697d)
...
...
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import GraniteConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...
...
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/grok.py  (view file @ 2add697d)
...
...
@@ -22,7 +22,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/internlm2.py  (view file @ 2add697d)
...
...
@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...
...
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/llama.py  (view file @ 2add697d)
...
...
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -39,6 +38,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/minicpm.py  (view file @ 2add697d)
...
...
@@ -18,7 +18,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...
...
@@ -31,6 +30,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/minicpm3.py  (view file @ 2add697d)
...
...
@@ -19,7 +19,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...
...
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/mixtral.py  (view file @ 2add697d)
...
...
@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import MixtralConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
...
...
@@ -38,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/mixtral_quant.py  (view file @ 2add697d)
...
...
@@ -23,7 +23,6 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
...
...
@@ -39,6 +38,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...
python/sglang/srt/models/olmo.py  (view file @ 2add697d)
...
...
@@ -20,7 +20,6 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import OlmoConfig
-from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
...
...
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
...
...