Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xdb4_94051
vllm
Commits
e67b4f2c
Unverified
Commit
e67b4f2c
authored
Sep 11, 2023
by
Woosuk Kwon
Committed by
GitHub
Sep 11, 2023
Browse files
Use FP32 in RoPE initialization (#1004)
Co-authored-by:
One
<
imone@tuta.io
>
parent
d6770d1f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
6 deletions
+7
-6
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+3
-2
vllm/model_executor/layers/attention.py
vllm/model_executor/layers/attention.py
+4
-4
No files found.
tests/kernels/test_pos_encoding.py
View file @
e67b4f2c
...
@@ -133,9 +133,10 @@ def test_rotary_embedding(
...
@@ -133,9 +133,10 @@ def test_rotary_embedding(
device
=
"cuda"
)
device
=
"cuda"
)
# Create the rotary embedding.
# Create the rotary embedding.
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
)
/
rotary_dim
))
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
dtype
=
torch
.
float
)
/
rotary_dim
))
t
=
torch
.
arange
(
max_position
).
float
()
t
=
torch
.
arange
(
max_position
).
float
()
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
.
float
()
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
sin
=
freqs
.
sin
()
cos_sin_cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cos_sin_cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
...
...
vllm/model_executor/layers/attention.py
View file @
e67b4f2c
...
@@ -264,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
...
@@ -264,10 +264,10 @@ class PagedAttentionWithRoPE(PagedAttention):
self
.
is_neox_style
=
is_neox_style
self
.
is_neox_style
=
is_neox_style
# Create the cos and sin cache.
# Create the cos and sin cache.
inv_freq
=
1.0
/
(
base
**
(
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
torch
.
arange
(
0
,
rotary_dim
,
2
,
device
=
"cuda"
)
/
rotary_dim
))
0
,
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
/
rotary_dim
))
t
=
torch
.
arange
(
max_position
,
device
=
"cuda"
)
.
float
()
t
=
torch
.
arange
(
max_position
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
.
float
()
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
freqs
.
cos
()
cos
=
freqs
.
cos
()
sin
=
freqs
.
sin
()
sin
=
freqs
.
sin
()
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment