Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
e67b4f2c
Unverified
Commit
e67b4f2c
authored
Sep 11, 2023
by
Woosuk Kwon
Committed by
GitHub
Sep 11, 2023
Browse files
Use FP32 in RoPE initialization (#1004)
Co-authored-by:
One
<
imone@tuta.io
>
parent
d6770d1f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
6 deletions
+7
-6
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+3
-2
vllm/model_executor/layers/attention.py
vllm/model_executor/layers/attention.py
+4
-4
No files found.
tests/kernels/test_pos_encoding.py
View file @
e67b4f2c
...
...
@@ -133,9 +133,10 @@ def test_rotary_embedding(
device
=
"cuda"
)
# Create the rotary embedding.
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
)
/
rotary_dim
))
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
dtype
=
torch
.
float
)
/
rotary_dim
))
t
=
torch
.
arange
(
max_position
).
float
()
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
.
float
()
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cos_sin_cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
...
...
vllm/model_executor/layers/attention.py
View file @
e67b4f2c
...
...
@@ -264,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
self
.
is_neox_style
=
is_neox_style
# Create the cos and sin cache.
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
device
=
"cuda"
)
/
rotary_dim
))
t
=
torch
.
arange
(
max_position
,
device
=
"cuda"
)
.
float
()
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
.
float
()
)
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
/
rotary_dim
))
t
=
torch
.
arange
(
max_position
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment