Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ab7165f2
Unverified
Commit
ab7165f2
authored
Aug 18, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 18, 2024
Browse files
[TPU] Optimize RoPE forward_native2 (#7636)
parent
0c2fa50b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
26 deletions
+27
-26
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+27
-26
No files found.
vllm/model_executor/layers/rotary_embedding.py
View file @
ab7165f2
...
...
@@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
def
_apply_rotary_emb
(
x
:
torch
.
Tensor
,
freqs_cis
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
x_
=
torch
.
view_as_complex
(
torch
.
stack
(
torch
.
chunk
(
x
.
transpose
(
1
,
2
).
float
(),
2
,
dim
=-
1
),
dim
=-
1
))
x_out
=
torch
.
view_as_real
(
x_
*
freqs_cis
).
type_as
(
x
)
x_out
=
torch
.
cat
(
torch
.
chunk
(
x_out
,
2
,
dim
=-
1
),
dim
=-
2
)
x_out
=
x_out
.
reshape
(
x_out
.
shape
[
0
],
x_out
.
shape
[
1
],
x_out
.
shape
[
2
],
-
1
).
transpose
(
1
,
2
)
return
x_out
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
"""
orig_dtype
=
x
.
dtype
x
=
x
.
float
()
x1
,
x2
=
torch
.
chunk
(
x
,
2
,
dim
=-
1
)
cos
=
cos
.
unsqueeze
(
-
2
)
sin
=
sin
.
unsqueeze
(
-
2
)
o1
=
x1
*
cos
-
x2
*
sin
o2
=
x2
*
cos
+
x1
*
sin
return
torch
.
cat
((
o1
,
o2
),
dim
=-
1
).
to
(
orig_dtype
)
class
RotaryEmbedding
(
CustomOp
):
...
...
@@ -78,14 +86,10 @@ class RotaryEmbedding(CustomOp):
self
.
dtype
=
dtype
cache
=
self
.
_compute_cos_sin_cache
()
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
use_native2
=
current_platform
.
is_tpu
()
and
is_neox_style
if
not
self
.
use_native2
:
cache
=
cache
.
to
(
dtype
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
else
:
cos
,
sin
=
cache
.
chunk
(
2
,
dim
=-
1
)
freqs_cis
=
cos
+
1j
*
sin
self
.
register_buffer
(
"freqs_cis"
,
freqs_cis
,
persistent
=
False
)
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
...
...
@@ -173,28 +177,25 @@ class RotaryEmbedding(CustomOp):
This method might perform better than `forward_native()` when compiled.
"""
if
positions
.
dim
()
==
1
:
batch_size
=
1
seq_len
=
positions
.
shape
[
0
]
else
:
batch_size
,
seq_len
=
positions
.
shape
if
offsets
is
not
None
:
positions
=
positions
+
offsets
freqs_cis
=
self
.
freqs_cis
.
index_select
(
0
,
positions
.
flatten
())
freqs_cis
=
freqs_cis
.
view
(
batch_size
,
1
,
seq_len
,
-
1
)
positions
=
positions
.
flatten
()
num_tokens
=
positions
.
shape
[
0
]
cos_sin
=
self
.
cos_sin_cache
.
index_select
(
0
,
positions
)
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
query_shape
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_pass
=
query
[...,
self
.
rotary_dim
:]
query_rot
=
_apply_rotary_emb
(
query_rot
,
freqs_cis
)
query_rot
=
_apply_rotary_emb
(
query_rot
,
cos
,
sin
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
key
=
key
.
view
(
batch_size
,
seq_len
,
-
1
,
self
.
head_size
)
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
_apply_rotary_emb
(
key_rot
,
freqs_cis
)
key_rot
=
_apply_rotary_emb
(
key_rot
,
cos
,
sin
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment