Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1aa2f81b
Unverified
Commit
1aa2f81b
authored
May 30, 2025
by
Cyrus Leung
Committed by
GitHub
May 30, 2025
Browse files
[Misc] Update type annotation for rotary embedding `base` (#18914)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
d54af615
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
23 additions
and
26 deletions
+23
-26
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+1
-1
tests/kernels/core/test_pos_encoding.py
tests/kernels/core/test_pos_encoding.py
+3
-3
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+17
-17
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/minimax_text_01.py
+2
-5
No files found.
benchmarks/kernels/benchmark_rope.py
View file @
1aa2f81b
...
...
@@ -22,7 +22,7 @@ def benchmark_rope_kernels_multi_lora(
seed
:
int
,
device
:
str
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
float
=
10000
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
tests/kernels/core/test_pos_encoding.py
View file @
1aa2f81b
...
...
@@ -70,7 +70,7 @@ def test_rotary_embedding(
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
float
=
10000
,
)
->
None
:
if
rotary_dim
is
None
:
rotary_dim
=
head_size
...
...
@@ -135,7 +135,7 @@ def test_batched_rotary_embedding(
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
float
=
10000
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
@@ -203,7 +203,7 @@ def test_batched_rotary_embedding_multi_lora(
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
float
=
10000
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
1aa2f81b
...
...
@@ -96,7 +96,7 @@ class RotaryEmbedding(CustomOp):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
)
->
None
:
...
...
@@ -113,7 +113,7 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
]
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
...
...
@@ -404,7 +404,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
scaling_factors
:
Union
[
list
[
float
],
float
],
dtype
:
torch
.
dtype
,
...
...
@@ -464,7 +464,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
...
...
@@ -474,7 +474,7 @@ class NTKScalingRotaryEmbedding(RotaryEmbedding):
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
]
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
base
=
self
.
base
*
(
self
.
scaling_factor
if
self
.
mixed_b
is
None
else
1
)
inv_freq
=
super
().
_compute_inv_freq
(
base
)
...
...
@@ -501,7 +501,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
...
...
@@ -582,7 +582,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
...
...
@@ -644,7 +644,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
rotary_dim
:
int
,
max_position_embeddings
:
int
,
original_max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
short_factor
:
list
[
float
],
...
...
@@ -769,7 +769,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
...
...
@@ -877,7 +877,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
scaling_factor
:
float
,
...
...
@@ -892,7 +892,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
]
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
low_freq_wavelen
=
self
.
orig_max_position
/
self
.
low_freq_factor
high_freq_wavelen
=
self
.
orig_max_position
/
self
.
high_freq_factor
...
...
@@ -923,14 +923,14 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
):
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
]
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
inv_freqs
=
super
().
_compute_inv_freq
(
base
)
inv_freqs
=
inv_freqs
[:(
self
.
rotary_dim
//
2
)]
return
inv_freqs
...
...
@@ -989,7 +989,7 @@ class MRotaryEmbedding(RotaryEmbedding):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
mrope_section
:
Optional
[
list
[
int
]]
=
None
,
...
...
@@ -1529,7 +1529,7 @@ class DualChunkRotaryEmbedding(CustomOp):
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
dtype
:
torch
.
dtype
,
chunk_size
:
int
,
...
...
@@ -1558,7 +1558,7 @@ class DualChunkRotaryEmbedding(CustomOp):
q_inter_cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
]
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
...
...
@@ -1705,7 +1705,7 @@ def get_rope(
head_size
:
int
,
rotary_dim
:
int
,
max_position
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
=
True
,
rope_scaling
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
...
...
vllm/model_executor/models/minimax_text_01.py
View file @
1aa2f81b
...
...
@@ -141,7 +141,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
head_size
:
int
,
rotary_dim
:
int
,
max_position
:
int
,
base
:
int
,
base
:
float
,
is_neox_style
:
bool
,
cache_dtype
:
torch
.
dtype
,
)
->
None
:
...
...
@@ -155,10 +155,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
cache
=
self
.
_compute_cos_sin_cache
().
to
(
cache_dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
],
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
float
)
/
self
.
rotary_dim
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment