chenpangpang / transformers · Commits

Commit 083e13b7 (unverified), authored Aug 02, 2024 by Joao Gante; committed by GitHub on Aug 02, 2024
Parent commit: 2af199c4

RoPE: Add numerical tests ✨ (#32380)

tests! :D
Showing 2 changed files, with 329 additions and 5 deletions:

  src/transformers/modeling_rope_utils.py   (+9, -4)
  tests/utils/test_modeling_rope_utils.py   (+320, -1)
src/transformers/modeling_rope_utils.py
@@ -150,7 +150,7 @@ def _compute_dynamic_ntk_parameters(
     attention_factor = 1.0  # Unused in this type of RoPE

     # seq_len: default to max_position_embeddings, e.g. at init time
-    seq_len = seq_len if seq_len is not None else max_position_embeddings
+    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

     # Compute the inverse frequencies
     base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
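For intuition, the following is a minimal, standalone sketch of what the tightened guard changes in the dynamic NTK computation above. The concrete `base`, `dim`, `factor`, and `max_position_embeddings` values are stand-ins for Llama-style defaults, not values taken from the diff:

import torch

# Stand-in values (assumptions for illustration only)
base, dim, factor, max_position_embeddings = 10000.0, 128, 2.0, 2048

def dynamic_ntk_inv_freq(seq_len):
    # With the tightened guard, scaling only kicks in beyond the original training length
    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
    new_base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
    return 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

# Short inputs now leave the base (and hence the frequencies) untouched; long inputs shrink the inverse frequencies
print(torch.allclose(dynamic_ntk_inv_freq(seq_len=512), dynamic_ntk_inv_freq(seq_len=None)))  # True
print(bool(dynamic_ntk_inv_freq(seq_len=8192)[1] < dynamic_ntk_inv_freq(seq_len=None)[1]))    # True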
@@ -210,7 +210,7 @@ def _compute_yarn_parameters(
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
         return max(low, 0), min(high, dim - 1)

-    def linear_ramp_mask(min, max, dim):
+    def linear_ramp_factor(min, max, dim):
         if min == max:
             max += 0.001  # Prevent singularity
@@ -218,6 +218,8 @@ def _compute_yarn_parameters(
         ramp_func = torch.clamp(linear_func, 0, 1)
         return ramp_func

+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
     pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
     inv_freq_extrapolation = 1.0 / pos_freqs
     inv_freq_interpolation = 1.0 / (factor * pos_freqs)
@@ -225,8 +227,11 @@ def _compute_yarn_parameters(
     low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)

     # Get n-dimensional rotational scaling corrected for extrapolation
-    inv_freq_mask = 1 - linear_ramp_mask(low, high, dim // 2).float().to(device)
-    inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )

     return inv_freq, attention_factor
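The renames above make the blending step easier to read: `inv_freq_extrapolation_factor` is 1 where RoPE is left untouched (the high-frequency dimensions below `low`) and 0 where the full linear scaling is applied (the low-frequency dimensions above `high`). Below is a standalone sketch of that blend, with `low` and `high` picked by hand for illustration rather than computed via `find_correction_range`:

import torch

dim, base, factor = 128, 10000.0, 10.0  # assumed Llama-style values
low, high = 16, 40                      # hand-picked boundaries for illustration

def linear_ramp_factor(min_, max_, size):
    # 0 before `min_`, 1 after `max_`, linear in between (mirrors the renamed helper)
    if min_ == max_:
        max_ += 0.001
    ramp = (torch.arange(size, dtype=torch.float32) - min_) / (max_ - min_)
    return torch.clamp(ramp, 0, 1)

pos_freqs = base ** (torch.arange(0, dim, 2).float() / dim)
inv_freq_extrapolation = 1.0 / pos_freqs             # original RoPE frequencies
inv_freq_interpolation = 1.0 / (factor * pos_freqs)  # linearly scaled frequencies

inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = (
    inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
    + inv_freq_extrapolation * inv_freq_extrapolation_factor
)
print(torch.allclose(inv_freq[:low], inv_freq_extrapolation[:low]))             # True: unscaled below `low`
print(torch.allclose(inv_freq[high:], inv_freq_extrapolation[high:] / factor))  # True: fully scaled above `high`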
tests/utils/test_modeling_rope_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.

+import math
 import unittest

 from transformers import LlamaConfig
@@ -116,5 +117,323 @@ class RopeTest(unittest.TestCase):
            kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
            torch.testing.assert_close(config_freqs, kwargs_freqs)

    def test_default_rope_numerically(self):
        # Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
        # multiple RoPE strategies will fail.
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01, 4.2170e-01, 3.6517e-01,
                3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01, 1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01,
                1.0000e-01, 8.6596e-02, 7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
                3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02, 1.3335e-02, 1.1548e-02,
                1.0000e-02, 8.6596e-03, 7.4989e-03, 6.4938e-03, 5.6234e-03, 4.8697e-03, 4.2170e-03, 3.6517e-03,
                3.1623e-03, 2.7384e-03, 2.3714e-03, 2.0535e-03, 1.7783e-03, 1.5399e-03, 1.3335e-03, 1.1548e-03,
                1.0000e-03, 8.6596e-04, 7.4989e-04, 6.4938e-04, 5.6234e-04, 4.8697e-04, 4.2170e-04, 3.6517e-04,
                3.1623e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04
            ], device=torch_device
        )
        # fmt: on
        # TODO(joao): numerical checks for the different RoPE fns

        # input sanity checks: if these change, the output will also change
        config = LlamaConfig()
        self.assertEqual(config.rope_scaling, None)
        self.assertEqual(config.hidden_size, 4096)
        self.assertEqual(config.num_attention_heads, 32)
        self.assertEqual(config.rope_theta, 10000.0)
        self.assertFalse(hasattr(config, "partial_rotary_factor"))

        rope_fn = ROPE_INIT_FUNCTIONS["default"]
        inv_freq, attention_scale = rope_fn(config=config, device=torch_device)

        self.assertEqual(attention_scale, 1.0)  # attention scale is always 1 for default RoPE
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
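For reference, the snapshot above is the textbook RoPE inverse-frequency grid and can be re-derived independently of the library. A quick sketch, with the head dimension of 128 inferred from the sanity-checked `hidden_size` and `num_attention_heads`:

import torch

base, dim = 10000.0, 4096 // 32  # rope_theta and head dim implied by the sanity checks above
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
print(inv_freq[:4])  # approximately tensor([1.0000, 0.8660, 0.7499, 0.6494]), matching EXPECTED_INV_FREQ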
    def test_linear_rope_numerically(self):
        # This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
        # frequencies (= the inverse frequencies are scaled **inversely**)
        config = LlamaConfig()
        default_rope_fn = ROPE_INIT_FUNCTIONS["default"]
        default_inv_freq, _ = default_rope_fn(config=config, device=torch_device)

        rope_fn = ROPE_INIT_FUNCTIONS["linear"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {"rope_type": "linear", "factor": factor}
            inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 1.0)  # attention scale is always 1 for linear RoPE
            torch.testing.assert_close(inv_freq, default_inv_freq / factor)
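Dividing the inverse frequencies by `factor` is the frequency-side view of position interpolation: the rotation angle is the product of position and inverse frequency, so squeezing the positions or shrinking the frequencies gives the same angles. A small, library-independent check (all values are illustrative):

import torch

factor, dim = 10.0, 128
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(32).float()

angles_from_squeezed_positions = torch.outer(positions / factor, inv_freq)  # interpolate the position ids
angles_from_scaled_inv_freq = torch.outer(positions, inv_freq / factor)     # scale the inverse frequencies
print(torch.allclose(angles_from_squeezed_positions, angles_from_scaled_inv_freq))  # True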
    def test_dynamic_rope_numerically(self):
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.0931e-01, 6.5498e-01, 5.3008e-01, 4.2900e-01, 3.4720e-01, 2.8099e-01, 2.2741e-01,
                1.8404e-01, 1.4895e-01, 1.2055e-01, 9.7558e-02, 7.8955e-02, 6.3899e-02, 5.1714e-02, 4.1853e-02,
                3.3872e-02, 2.7413e-02, 2.2185e-02, 1.7955e-02, 1.4531e-02, 1.1760e-02, 9.5176e-03, 7.7027e-03,
                6.2339e-03, 5.0451e-03, 4.0831e-03, 3.3045e-03, 2.6744e-03, 2.1644e-03, 1.7517e-03, 1.4176e-03,
                1.1473e-03, 9.2852e-04, 7.5146e-04, 6.0817e-04, 4.9220e-04, 3.9834e-04, 3.2238e-04, 2.6091e-04,
                2.1115e-04, 1.7089e-04, 1.3830e-04, 1.1193e-04, 9.0585e-05, 7.3312e-05, 5.9332e-05, 4.8018e-05,
                3.8861e-05, 3.1451e-05, 2.5453e-05, 2.0600e-05, 1.6672e-05, 1.3492e-05, 1.0920e-05, 8.8374e-06,
                7.1522e-06, 5.7883e-06, 4.6845e-06, 3.7912e-06, 3.0683e-06, 2.4832e-06, 2.0097e-06, 1.6265e-06
            ], device=torch_device
        )
        # fmt: on

        # input sanity checks: if these change, the output will also change
        config = LlamaConfig()
        self.assertEqual(config.rope_scaling, None)
        self.assertEqual(config.hidden_size, 4096)
        self.assertEqual(config.num_attention_heads, 32)
        self.assertEqual(config.rope_theta, 10000.0)
        self.assertFalse(hasattr(config, "partial_rotary_factor"))

        rope_fn = ROPE_INIT_FUNCTIONS["default"]
        default_inv_freq, _ = rope_fn(config=config, device=torch_device)

        # Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
        # model's original training sequence length
        rope_fn = ROPE_INIT_FUNCTIONS["dynamic"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
            inv_freq, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 1.0)  # attention scale is always 1 for dynamic RoPE
            torch.testing.assert_close(inv_freq, default_inv_freq)

            inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=1)
            torch.testing.assert_close(inv_freq, default_inv_freq)

        # Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
        # will scale up (i.e., the inverse frequencies will scale down).
        factor = 10.0
        config.rope_scaling = {"rope_type": "dynamic", "factor": factor}
        inv_freq, _ = rope_fn(config=config, device=torch_device, seq_len=16384)
        with self.assertRaises(AssertionError):  # It is NOT a linear factor
            torch.testing.assert_close(inv_freq, default_inv_freq / factor)
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
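The snapshot is not simply `default_inv_freq / factor` because dynamic NTK rescales the RoPE base, so the effective per-dimension ratio grows with the dimension index instead of being constant. A standalone sketch, assuming the sanity-checked defaults (`rope_theta=10000`, head dim 128) plus a LlamaConfig-style `max_position_embeddings` of 2048:

import torch

base, dim, max_pos = 10000.0, 128, 2048  # assumed defaults
factor, seq_len = 10.0, 16384

default_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
new_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
dynamic_inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))

print(dynamic_inv_freq[1])                        # approximately 0.8093, matching EXPECTED_INV_FREQ[1]
print((default_inv_freq / dynamic_inv_freq)[:3])  # approximately tensor([1.0000, 1.0700, 1.1449]), not a constant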
    def test_yarn_rope_numerically(self):
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01, 4.2170e-01, 3.6517e-01,
                3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01, 1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01,
                1.0000e-01, 8.3479e-02, 6.9590e-02, 5.7925e-02, 4.8136e-02, 3.9931e-02, 3.3061e-02, 2.7315e-02,
                2.2515e-02, 1.8512e-02, 1.5177e-02, 1.2403e-02, 1.0101e-02, 8.1924e-03, 6.6143e-03, 5.3120e-03,
                4.2400e-03, 3.3599e-03, 2.6396e-03, 2.0520e-03, 1.5746e-03, 1.1882e-03, 8.7713e-04, 6.2810e-04,
                4.3007e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
                1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05, 4.2170e-05, 3.6517e-05,
                3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05, 1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
            ], device=torch_device
        )
        # fmt: on

        # input sanity checks: if these change, the output will also change
        config = LlamaConfig()
        self.assertEqual(config.rope_scaling, None)
        self.assertEqual(config.hidden_size, 4096)
        self.assertEqual(config.num_attention_heads, 32)
        self.assertEqual(config.rope_theta, 10000.0)
        self.assertFalse(hasattr(config, "partial_rotary_factor"))

        rope_fn = ROPE_INIT_FUNCTIONS["default"]
        default_inv_freq, _ = rope_fn(config=config, device=torch_device)

        # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
        # `0.1 * math.log(factor) + 1.0`
        rope_fn = ROPE_INIT_FUNCTIONS["yarn"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {"rope_type": "yarn", "factor": factor}
            _, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 0.1 * math.log(factor) + 1.0)

            config.rope_scaling = {"rope_type": "yarn", "factor": factor, "attention_factor": 0.5}
            _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
            self.assertEqual(attention_scale, 0.5)

        # Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
        # Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
        # `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
        # (note: adds a margin to the test for numerical stability)
        factor = 10.0
        margin = 1e-8
        config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_bounded_by_factor = [
            ((default_inv_freq[idx] / factor) - margin) <= yarn_inv_freq_value <= (default_inv_freq[idx] + margin)
            for idx, yarn_inv_freq_value in enumerate(inv_freq)
        ]
        self.assertTrue(all(is_bounded_by_factor))

        # super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
        # values (empirically checked for `beta_fast` = 1000) should be very similar to linear scaling
        config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 1000, "beta_slow": 1}
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_interpolating = [
            yarn_inv_freq_value < (default_inv_freq[idx] + margin)
            for idx, yarn_inv_freq_value in enumerate(inv_freq)
        ]
        self.assertFalse(is_interpolating[0])
        self.assertTrue(all(is_interpolating[1:]))
        torch.testing.assert_close(inv_freq[-20:], default_inv_freq[-20:] / factor)

        # Check 3: numerical snapshot to avoid regressions
        config.rope_scaling = {"rope_type": "yarn", "factor": factor, "beta_fast": 32, "beta_slow": 1}
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
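The default attention scaling asserted in Check 1 is a closed-form expression, so the expected values can be computed by hand for the factors used in the loop:

import math

# Default YaRN attention factor, 0.1 * ln(factor) + 1.0
for factor in (2.0, 10.0, 20.0):
    print(factor, 0.1 * math.log(factor) + 1.0)  # approximately 1.0693, 1.2303, 1.2996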
    def test_longrope_rope_numerically(self):
        # input sanity checks: if these change, the output will also change
        config = LlamaConfig()
        self.assertEqual(config.rope_scaling, None)
        self.assertEqual(config.hidden_size, 4096)
        self.assertEqual(config.num_attention_heads, 32)
        self.assertEqual(config.rope_theta, 10000.0)
        self.assertFalse(hasattr(config, "partial_rotary_factor"))

        # longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on `factor`
        dim = config.hidden_size // config.num_attention_heads
        short_factor = [2.0] * (dim // 2)  # scaling applied when factor == 1.0
        long_factor = torch.ones(dim // 2).cumsum(0).tolist()  # scaling applied when factor > 1.0

        rope_fn = ROPE_INIT_FUNCTIONS["default"]
        default_inv_freq, _ = rope_fn(config=config, device=torch_device)

        # Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
        # `math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))`
        rope_fn = ROPE_INIT_FUNCTIONS["longrope"]
        max_position_embeddings = config.max_position_embeddings
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {
                "rope_type": "longrope",
                "factor": factor,
                "short_factor": short_factor,
                "long_factor": long_factor,
            }
            _, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)))

            config.rope_scaling = {
                "rope_type": "longrope",
                "factor": factor,
                "short_factor": short_factor,
                "long_factor": long_factor,
                "attention_factor": 0.5,
            }
            _, attention_scale = rope_fn(config=config, device=torch_device, seq_len=1)
            self.assertEqual(attention_scale, 0.5)

        # Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
        factor = 1.0
        config.rope_scaling = {
            "rope_type": "longrope",
            "factor": factor,
            "short_factor": short_factor,
            "long_factor": long_factor,
        }
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(short_factor).to(torch_device))

        # Check 3: Factor > 1.0 -> long factor is applied to the default frequencies
        factor = 10.0
        config.rope_scaling = {
            "rope_type": "longrope",
            "factor": factor,
            "short_factor": short_factor,
            "long_factor": long_factor,
        }
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        torch.testing.assert_close(inv_freq, default_inv_freq / torch.tensor(long_factor).to(torch_device))
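Check 1's default longrope attention factor is likewise a closed-form expression; the value 2048 used below for `max_position_embeddings` is the LlamaConfig default assumed throughout these tests:

import math

max_position_embeddings = 2048  # assumed LlamaConfig default
for factor in (2.0, 10.0, 20.0):
    print(factor, math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings)))
    # approximately 1.0445, 1.1410, 1.1802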
    def test_llama3_rope_numerically(self):
        # fmt: off
        EXPECTED_INV_FREQ = torch.tensor(
            [
                1.0000e+00, 8.6596e-01, 7.4989e-01, 6.4938e-01, 5.6234e-01, 4.8697e-01, 4.2170e-01, 3.6517e-01,
                3.1623e-01, 2.7384e-01, 2.3714e-01, 2.0535e-01, 1.7783e-01, 1.5399e-01, 1.3335e-01, 1.1548e-01,
                1.0000e-01, 8.6596e-02, 7.4989e-02, 6.4938e-02, 5.6234e-02, 4.8697e-02, 4.2170e-02, 3.6517e-02,
                3.1623e-02, 2.7384e-02, 2.3714e-02, 2.0535e-02, 1.7783e-02, 1.5399e-02, 1.3335e-02, 1.0730e-02,
                7.7785e-03, 5.6009e-03, 3.9991e-03, 2.8248e-03, 1.9675e-03, 1.3449e-03, 8.9549e-04, 5.7363e-04,
                3.4539e-04, 2.7384e-04, 2.3714e-04, 2.0535e-04, 1.7783e-04, 1.5399e-04, 1.3335e-04, 1.1548e-04,
                1.0000e-04, 8.6596e-05, 7.4989e-05, 6.4938e-05, 5.6234e-05, 4.8697e-05, 4.2170e-05, 3.6517e-05,
                3.1623e-05, 2.7384e-05, 2.3714e-05, 2.0535e-05, 1.7783e-05, 1.5399e-05, 1.3335e-05, 1.1548e-05
            ], device=torch_device
        )
        # fmt: on

        # input sanity checks: if these change, the output will also change
        config = LlamaConfig()
        self.assertEqual(config.rope_scaling, None)
        self.assertEqual(config.hidden_size, 4096)
        self.assertEqual(config.num_attention_heads, 32)
        self.assertEqual(config.rope_theta, 10000.0)
        self.assertFalse(hasattr(config, "partial_rotary_factor"))

        rope_fn = ROPE_INIT_FUNCTIONS["default"]
        default_inv_freq, _ = rope_fn(config=config, device=torch_device)

        # Check 1: `attention_factor` is always 1
        rope_fn = ROPE_INIT_FUNCTIONS["llama3"]
        for factor in (2.0, 10.0, 20.0):
            config.rope_scaling = {
                "rope_type": "llama3",
                "factor": factor,
                "original_max_position_embeddings": 2048,
                "low_freq_factor": 1,
                "high_freq_factor": 4,
            }
            _, attention_scale = rope_fn(config=config, device=torch_device)
            self.assertEqual(attention_scale, 1.0)

        # Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
        # `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequencies see no change, medium
        # frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
        # is considered low, medium, and high frequencies.
        factor = 10.0
        config.rope_scaling = {
            "rope_type": "llama3",
            "factor": factor,
            "original_max_position_embeddings": 2048,
            "low_freq_factor": 1,
            "high_freq_factor": 4,
        }
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_bounded_by_factor = [
            (default_inv_freq[idx] / factor) <= llama3_inv_freq_value <= default_inv_freq[idx]
            for idx, llama3_inv_freq_value in enumerate(inv_freq)
        ]
        self.assertTrue(all(is_bounded_by_factor))

        # if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
        # scaled
        config.rope_scaling = {
            "rope_type": "llama3",
            "factor": factor,
            "original_max_position_embeddings": 2048,
            "low_freq_factor": 1,
            "high_freq_factor": 1000,
        }
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        is_scaled = [
            yarn_inv_freq_value < default_inv_freq[idx] for idx, yarn_inv_freq_value in enumerate(inv_freq)
        ]
        self.assertTrue(all(is_scaled))

        # Check 3: numerical snapshot to avoid regressions
        config.rope_scaling = {
            "rope_type": "llama3",
            "factor": factor,
            "original_max_position_embeddings": 2048,
            "low_freq_factor": 1,
            "high_freq_factor": 4,
        }
        inv_freq, _ = rope_fn(config=config, device=torch_device)
        torch.testing.assert_close(inv_freq, EXPECTED_INV_FREQ)
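The bound asserted in Check 2 follows from the piecewise, wavelength-based rule that llama3-style scaling applies: long wavelengths (low frequencies) get the full `1/factor` scaling, short wavelengths are untouched, and the band in between is smoothly interpolated. The sketch below is a hedged re-implementation of that rule from its general description, not the function under test, and reuses the parameter names from the test's `rope_scaling` dict:

import math

import torch

def llama3_like_scaling(inv_freq, factor=10.0, original_max=2048, low_freq_factor=1, high_freq_factor=4):
    low_freq_wavelen = original_max / low_freq_factor    # wavelengths longer than this: scale by 1/factor
    high_freq_wavelen = original_max / high_freq_factor  # wavelengths shorter than this: leave untouched
    wavelen = 2 * math.pi / inv_freq
    smooth = (original_max / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    out = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, smoothed)
    out = torch.where(wavelen < high_freq_wavelen, inv_freq, out)
    return out

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 128, 2).float() / 128))
scaled = llama3_like_scaling(inv_freq)
# Mirrors `is_bounded_by_factor`: every scaled value sits between inv_freq / factor and inv_freq
print(bool(((inv_freq / 10.0 <= scaled) & (scaled <= inv_freq)).all()))  # True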