Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
87499420
Unverified
Commit
87499420
authored
Oct 06, 2023
by
rui-ren
Committed by
GitHub
Oct 06, 2023
Browse files
fix RoPE t range issue for fp16 (#26602)
parent
ea52ed9d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
src/transformers/models/falcon/modeling_falcon.py
src/transformers/models/falcon/modeling_falcon.py
+3
-3
No files found.
src/transformers/models/falcon/modeling_falcon.py
View file @
87499420
...
@@ -108,7 +108,7 @@ class FalconRotaryEmbedding(nn.Module):
...
@@ -108,7 +108,7 @@ class FalconRotaryEmbedding(nn.Module):
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
self
.
seq_len_cached
=
seq_len
self
.
seq_len_cached
=
seq_len
t
=
torch
.
arange
(
seq_len
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
t
=
torch
.
arange
(
seq_len
,
device
=
device
).
to
(
dtype
)
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
...
@@ -171,7 +171,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
...
@@ -171,7 +171,7 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
self
.
seq_len_cached
=
seq_len
self
.
seq_len_cached
=
seq_len
t
=
torch
.
arange
(
seq_len
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
t
=
torch
.
arange
(
seq_len
,
device
=
device
).
to
(
dtype
)
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
t
=
t
/
self
.
scaling_factor
t
=
t
/
self
.
scaling_factor
...
@@ -208,7 +208,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
...
@@ -208,7 +208,7 @@ class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
head_dim
,
2
).
float
().
to
(
device
)
/
self
.
head_dim
))
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
head_dim
,
2
).
float
().
to
(
device
)
/
self
.
head_dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
t
=
torch
.
arange
(
seq_len
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
t
=
torch
.
arange
(
seq_len
,
device
=
device
).
to
(
dtype
)
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment