Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5febdc87
Unverified
Commit
5febdc87
authored
Sep 13, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 13, 2025
Browse files
[Chore] Remove unused batched RoPE op & kernel (#24789)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
99bfef84
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
16 additions
and
348 deletions
+16
-348
csrc/ops.h
csrc/ops.h
+0
-6
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+0
-122
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-10
tests/kernels/core/test_pos_encoding.py
tests/kernels/core/test_pos_encoding.py
+1
-146
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+6
-16
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-10
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+0
-11
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+9
-27
No files found.
csrc/ops.h
View file @
5febdc87
...
@@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
...
@@ -122,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
void
silu_and_mul_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
...
...
csrc/pos_encoding_kernels.cu
View file @
5febdc87
...
@@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
...
@@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
token_idx
,
query_stride
,
key_stride
,
head_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
batched_rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
// [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
cos_sin_cache_offset
=
cos_sin_cache_offsets
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
// namespace vllm
}
// namespace vllm
void
rotary_embedding
(
void
rotary_embedding
(
...
@@ -211,96 +182,3 @@ void rotary_embedding(
...
@@ -211,96 +182,3 @@ void rotary_embedding(
}
}
});
});
}
}
/*
Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std
::
optional
<
torch
::
Tensor
>
key
,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens] or [batch_size]
)
{
// num_tokens = batch_size * seq_len
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
TORCH_CHECK
(
positions
.
size
(
0
)
==
num_tokens
||
positions
.
numel
()
==
num_tokens
,
"positions must have the same num_tokens or batch_size as "
"cos_sin_cache_offsets"
);
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)),
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
))
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)),
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have concistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
csrc/torch_bindings.cpp
View file @
5febdc87
...
@@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -214,16 +214,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops
.
def
(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()"
);
ops
.
impl
(
"batched_rotary_embedding"
,
torch
::
kCUDA
,
&
batched_rotary_embedding
);
// Quantization ops
// Quantization ops
#ifndef USE_ROCM
#ifndef USE_ROCM
// Quantized GEMM for AWQ.
// Quantized GEMM for AWQ.
...
...
tests/kernels/core/test_pos_encoding.py
View file @
5febdc87
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
itertools
import
accumulate
,
product
from
itertools
import
product
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
import
pytest
import
pytest
...
@@ -111,151 +111,6 @@ def test_rotary_embedding(
...
@@ -111,151 +111,6 @@ def test_rotary_embedding(
"expected returned key to be None"
"expected returned key to be None"
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"tensor_shape_fn"
,
TENSORS_SHAPES_FN
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
ROTARY_DIMS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding
(
is_neox_style
:
bool
,
tensor_shape_fn
:
Callable
[[
int
,
int
,
int
,
int
],
tuple
[
int
]],
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
,
rotary_dim
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
base
:
float
=
10000
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
(
1
,
)
})
rope
=
rope
.
to
(
dtype
=
dtype
,
device
=
torch
.
get_default_device
())
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# slice tensor if required, noop otherwise
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
dtype
=
torch
.
long
,
device
=
device
))
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
if
use_key
:
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
ROTARY_DIMS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding_multi_lora
(
is_neox_style
:
bool
,
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
,
rotary_dim
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
base
:
float
=
10000
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
scaling_factors
:
list
[
int
]
=
[
1
,
2
,
4
]
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
})
rope
=
rope
.
to
(
dtype
=
dtype
,
device
=
torch
.
get_default_device
())
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
offset_map
=
torch
.
tensor
(
list
(
accumulate
([
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
])))
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
)
query_offsets
=
offset_map
[
query_types
]
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
,
query_offsets
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
# Compare the results.
torch
.
testing
.
assert_close
(
out_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
if
use_key
:
torch
.
testing
.
assert_close
(
out_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rope_module_cache
():
def
test_rope_module_cache
():
MAX_POSITIONS
=
[
123
,
1234
]
MAX_POSITIONS
=
[
123
,
1234
]
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
5febdc87
...
@@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
...
@@ -16,20 +16,14 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def
rotary_embedding_opcheck
(
rot
,
def
rotary_embedding_opcheck
(
rot
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
key
:
Optional
[
torch
.
Tensor
]
=
None
):
offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# ops.rotary_embedding() is a in-place operation
# are in-place operations that update the query and key tensors.
# that updates the query and key tensors.
if
offsets
is
not
None
:
opcheck
(
torch
.
ops
.
_C
.
rotary_embedding
,
opcheck
(
torch
.
ops
.
_C
.
batched_rotary_embedding
,
(
positions
,
query
,
key
,
rot
.
head_size
,
cos_sin_cache
,
(
positions
,
query
,
key
,
rot
.
head_size
,
cos_sin_cache
,
rot
.
is_neox_style
))
rot
.
is_neox_style
,
rot
.
rotary_dim
,
offsets
))
else
:
opcheck
(
torch
.
ops
.
_C
.
rotary_embedding
,
(
positions
,
query
,
key
,
rot
.
head_size
,
cos_sin_cache
,
rot
.
is_neox_style
))
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cuda"
])
...
@@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -65,10 +59,6 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
key
=
key
[...,
:
head_size
]
if
use_key
else
None
key
=
key
[...,
:
head_size
]
if
use_key
else
None
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
device
=
device
,
dtype
=
torch
.
long
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
,
offsets
)
# if we have a contiguous head stride, test the alternate
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
# [..., num_heads * head_dim] shape/layout
...
...
vllm/_custom_ops.py
View file @
5febdc87
...
@@ -257,16 +257,6 @@ def rotary_embedding(
...
@@ -257,16 +257,6 @@ def rotary_embedding(
cos_sin_cache
,
is_neox
)
cos_sin_cache
,
is_neox
)
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
# layer norm ops
# layer norm ops
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
rms_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
epsilon
:
float
)
->
None
:
...
...
vllm/_ipex_ops.py
View file @
5febdc87
...
@@ -148,17 +148,6 @@ class ipex_ops:
...
@@ -148,17 +148,6 @@ class ipex_ops:
head_size
,
cos_sin_cache
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
)
is_neox
,
rot_dim
)
@
staticmethod
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
ipex
.
llm
.
functional
.
rotary_embedding_batched
(
positions
,
query
,
key
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
@
staticmethod
@
staticmethod
def
rms_norm
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
def
rms_norm
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
torch
.
Tensor
:
epsilon
:
float
)
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/rotary_embedding/base.py
View file @
5febdc87
...
@@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp):
...
@@ -62,11 +62,8 @@ class RotaryEmbedding(CustomOp):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""A PyTorch-native implementation of forward()."""
"""A PyTorch-native implementation of forward()."""
if
offsets
is
not
None
:
positions
=
positions
+
offsets
positions
=
positions
.
flatten
()
positions
=
positions
.
flatten
()
num_tokens
=
positions
.
shape
[
0
]
num_tokens
=
positions
.
shape
[
0
]
cos_sin
=
self
.
cos_sin_cache
.
index_select
(
0
,
positions
)
cos_sin
=
self
.
cos_sin_cache
.
index_select
(
0
,
positions
)
...
@@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp):
...
@@ -96,7 +93,6 @@ class RotaryEmbedding(CustomOp):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp):
...
@@ -107,16 +103,10 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# ops.rotary_embedding() is an in-place operation
# are in-place operations that update the query and key tensors.
# that updates the query and key tensors.
if
offsets
is
not
None
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
,
self
.
rotary_dim
,
offsets
)
else
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def
forward_xpu
(
def
forward_xpu
(
...
@@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp):
...
@@ -124,29 +114,21 @@ class RotaryEmbedding(CustomOp):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
# ops.rotary_embedding()
/batched_rotary_embedding()
# ops.rotary_embedding()
is an in-place operation
#
are in-place operations
that update the query and key tensors.
# that update
s
the query and key tensors.
if
key
is
None
:
if
key
is
None
:
# XPU kernel doesn't support key=None so fall back to native impl
# XPU kernel doesn't support key=None so fall back to native impl
# TODO(sarckk): add support for optional key in
# TODO(sarckk): add support for optional key in
# ipex.llm.functional.rotary_embedding_batched
# ipex.llm.functional.rotary_embedding_batched
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
return
self
.
forward_native
(
positions
,
query
,
key
)
else
:
else
:
if
offsets
is
not
None
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
,
self
.
rotary_dim
,
offsets
)
else
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment