Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cf601b90
Unverified
Commit
cf601b90
authored
Mar 18, 2023
by
Guangyuan Ma
Committed by
GitHub
Mar 17, 2023
Browse files
Fix Unnecessary move of tensors from CPU to GPU in LlamaRotaryEmbedding (#22234)
push
parent
bec07561
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
6 deletions
+6
-6
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+6
-6
No files found.
src/transformers/models/llama/modeling_llama.py
View file @
cf601b90
...
...
@@ -99,8 +99,8 @@ class LlamaRotaryEmbedding(torch.nn.Module):
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
)
self
.
cos_cached
=
emb
.
cos
()[
None
,
None
,
:,
:]
self
.
sin_cached
=
emb
.
sin
()[
None
,
None
,
:,
:]
self
.
register_buffer
(
"
cos_cached
"
,
emb
.
cos
()[
None
,
None
,
:,
:]
,
persistent
=
False
)
self
.
register_buffer
(
"
sin_cached
"
,
emb
.
sin
()[
None
,
None
,
:,
:]
,
persistent
=
False
)
def
forward
(
self
,
x
,
seq_len
=
None
):
# x: [bs, num_attention_heads, seq_len, head_size]
...
...
@@ -111,11 +111,11 @@ class LlamaRotaryEmbedding(torch.nn.Module):
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
x
.
device
)
self
.
cos_cached
=
emb
.
cos
()[
None
,
None
,
:,
:]
.
to
(
dtype
=
x
.
dtyp
e
)
self
.
sin_cached
=
emb
.
sin
()[
None
,
None
,
:,
:]
.
to
(
dtype
=
x
.
dtyp
e
)
self
.
register_buffer
(
"
cos_cached
"
,
emb
.
cos
()[
None
,
None
,
:,
:]
,
persistent
=
Fals
e
)
self
.
register_buffer
(
"
sin_cached
"
,
emb
.
sin
()[
None
,
None
,
:,
:]
,
persistent
=
Fals
e
)
return
(
self
.
cos_cached
[:,
:,
:
seq_len
,
...].
to
(
dtype
=
x
.
dtype
,
device
=
x
.
device
),
self
.
sin_cached
[:,
:,
:
seq_len
,
...].
to
(
dtype
=
x
.
dtype
,
device
=
x
.
device
),
self
.
cos_cached
[:,
:,
:
seq_len
,
...].
to
(
dtype
=
x
.
dtype
),
self
.
sin_cached
[:,
:,
:
seq_len
,
...].
to
(
dtype
=
x
.
dtype
),
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment