Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
083e13b7
Unverified
Commit
083e13b7
authored
Aug 02, 2024
by
Joao Gante
Committed by
GitHub
Aug 02, 2024
Browse files
RoPE: Add numerical tests
✨
(#32380)
tests! :D
parent
2af199c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
329 additions
and
5 deletions
+329
-5
src/transformers/modeling_rope_utils.py
src/transformers/modeling_rope_utils.py
+9
-4
tests/utils/test_modeling_rope_utils.py
tests/utils/test_modeling_rope_utils.py
+320
-1
No files found.
src/transformers/modeling_rope_utils.py
View file @
083e13b7
...
@@ -150,7 +150,7 @@ def _compute_dynamic_ntk_parameters(
...
@@ -150,7 +150,7 @@ def _compute_dynamic_ntk_parameters(
attention_factor
=
1.0
# Unused in this type of RoPE
attention_factor
=
1.0
# Unused in this type of RoPE
# seq_len: default to max_position_embeddings, e.g. at init time
# seq_len: default to max_position_embeddings, e.g. at init time
seq_len
=
seq_len
if
seq_len
is
not
None
else
max_position_embeddings
seq_len
=
seq_len
if
seq_len
is
not
None
and
seq_len
>
max_position_embeddings
else
max_position_embeddings
# Compute the inverse frequencies
# Compute the inverse frequencies
base
=
base
*
((
factor
*
seq_len
/
max_position_embeddings
)
-
(
factor
-
1
))
**
(
dim
/
(
dim
-
2
))
base
=
base
*
((
factor
*
seq_len
/
max_position_embeddings
)
-
(
factor
-
1
))
**
(
dim
/
(
dim
-
2
))
...
@@ -210,7 +210,7 @@ def _compute_yarn_parameters(
...
@@ -210,7 +210,7 @@ def _compute_yarn_parameters(
high
=
math
.
ceil
(
find_correction_dim
(
high_rot
,
dim
,
base
,
max_position_embeddings
))
high
=
math
.
ceil
(
find_correction_dim
(
high_rot
,
dim
,
base
,
max_position_embeddings
))
return
max
(
low
,
0
),
min
(
high
,
dim
-
1
)
return
max
(
low
,
0
),
min
(
high
,
dim
-
1
)
def
linear_ramp_
mask
(
min
,
max
,
dim
):
def
linear_ramp_
factor
(
min
,
max
,
dim
):
if
min
==
max
:
if
min
==
max
:
max
+=
0.001
# Prevent singularity
max
+=
0.001
# Prevent singularity
...
@@ -218,6 +218,8 @@ def _compute_yarn_parameters(
...
@@ -218,6 +218,8 @@ def _compute_yarn_parameters(
ramp_func
=
torch
.
clamp
(
linear_func
,
0
,
1
)
ramp_func
=
torch
.
clamp
(
linear_func
,
0
,
1
)
return
ramp_func
return
ramp_func
# Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
# to expand the possible context length. In other words, interpolation = apply scaling factor.
pos_freqs
=
base
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
().
to
(
device
)
/
dim
)
pos_freqs
=
base
**
(
torch
.
arange
(
0
,
dim
,
2
).
float
().
to
(
device
)
/
dim
)
inv_freq_extrapolation
=
1.0
/
pos_freqs
inv_freq_extrapolation
=
1.0
/
pos_freqs
inv_freq_interpolation
=
1.0
/
(
factor
*
pos_freqs
)
inv_freq_interpolation
=
1.0
/
(
factor
*
pos_freqs
)
...
@@ -225,8 +227,11 @@ def _compute_yarn_parameters(
...
@@ -225,8 +227,11 @@ def _compute_yarn_parameters(
low
,
high
=
find_correction_range
(
beta_fast
,
beta_slow
,
dim
,
base
,
max_position_embeddings
)
low
,
high
=
find_correction_range
(
beta_fast
,
beta_slow
,
dim
,
base
,
max_position_embeddings
)
# Get n-dimensional rotational scaling corrected for extrapolation
# Get n-dimensional rotational scaling corrected for extrapolation
inv_freq_mask
=
1
-
linear_ramp_mask
(
low
,
high
,
dim
//
2
).
float
().
to
(
device
)
inv_freq_extrapolation_factor
=
1
-
linear_ramp_factor
(
low
,
high
,
dim
//
2
).
float
().
to
(
device
)
inv_freq
=
inv_freq_interpolation
*
(
1
-
inv_freq_mask
)
+
inv_freq_extrapolation
*
inv_freq_mask
inv_freq
=
(
inv_freq_interpolation
*
(
1
-
inv_freq_extrapolation_factor
)
+
inv_freq_extrapolation
*
inv_freq_extrapolation_factor
)
return
inv_freq
,
attention_factor
return
inv_freq
,
attention_factor
...
...
tests/utils/test_modeling_rope_utils.py
View file @
083e13b7
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
import
math
import
unittest
import
unittest
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
...
@@ -116,5 +117,323 @@ class RopeTest(unittest.TestCase):
...
@@ -116,5 +117,323 @@ class RopeTest(unittest.TestCase):
kwargs_freqs
=
rope_fn
(
**
rope_kwargs
,
device
=
device
)[
0
]
kwargs_freqs
=
rope_fn
(
**
rope_kwargs
,
device
=
device
)[
0
]
torch
.
testing
.
assert_close
(
config_freqs
,
kwargs_freqs
)
torch
.
testing
.
assert_close
(
config_freqs
,
kwargs_freqs
)
def
test_default_rope_numerically
(
self
):
# Note: some RoPE scaling methods start off by calling the default RoPE frequencies. If this test fails, then
# multiple RoPE strategies will fail.
# fmt: off
EXPECTED_INV_FREQ
=
torch
.
tensor
(
[
1.0000e+00
,
8.6596e-01
,
7.4989e-01
,
6.4938e-01
,
5.6234e-01
,
4.8697e-01
,
4.2170e-01
,
3.6517e-01
,
3.1623e-01
,
2.7384e-01
,
2.3714e-01
,
2.0535e-01
,
1.7783e-01
,
1.5399e-01
,
1.3335e-01
,
1.1548e-01
,
1.0000e-01
,
8.6596e-02
,
7.4989e-02
,
6.4938e-02
,
5.6234e-02
,
4.8697e-02
,
4.2170e-02
,
3.6517e-02
,
3.1623e-02
,
2.7384e-02
,
2.3714e-02
,
2.0535e-02
,
1.7783e-02
,
1.5399e-02
,
1.3335e-02
,
1.1548e-02
,
1.0000e-02
,
8.6596e-03
,
7.4989e-03
,
6.4938e-03
,
5.6234e-03
,
4.8697e-03
,
4.2170e-03
,
3.6517e-03
,
3.1623e-03
,
2.7384e-03
,
2.3714e-03
,
2.0535e-03
,
1.7783e-03
,
1.5399e-03
,
1.3335e-03
,
1.1548e-03
,
1.0000e-03
,
8.6596e-04
,
7.4989e-04
,
6.4938e-04
,
5.6234e-04
,
4.8697e-04
,
4.2170e-04
,
3.6517e-04
,
3.1623e-04
,
2.7384e-04
,
2.3714e-04
,
2.0535e-04
,
1.7783e-04
,
1.5399e-04
,
1.3335e-04
,
1.1548e-04
],
device
=
torch_device
)
# fmt: on
# TODO(joao): numerical checks for the different RoPE fns
# input sanity checks: if these change, the output will also change
config
=
LlamaConfig
()
self
.
assertEqual
(
config
.
rope_scaling
,
None
)
self
.
assertEqual
(
config
.
hidden_size
,
4096
)
self
.
assertEqual
(
config
.
num_attention_heads
,
32
)
self
.
assertEqual
(
config
.
rope_theta
,
10000.0
)
self
.
assertFalse
(
hasattr
(
config
,
"partial_rotary_factor"
))
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"default"
]
inv_freq
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
self
.
assertEqual
(
attention_scale
,
1.0
)
# attention scale is always 1 for default RoPE
torch
.
testing
.
assert_close
(
inv_freq
,
EXPECTED_INV_FREQ
)
def
test_linear_rope_numerically
(
self
):
# This is a linear scaling strategy, the **frequencies** are scaled linearly with respect to the default
# frequencies (= the inverse frequencies are scaled **inversely**)
config
=
LlamaConfig
()
default_rope_fn
=
ROPE_INIT_FUNCTIONS
[
"default"
]
default_inv_freq
,
_
=
default_rope_fn
(
config
=
config
,
device
=
torch_device
)
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"linear"
]
for
factor
in
(
2.0
,
10.0
,
20.0
):
config
.
rope_scaling
=
{
"rope_type"
:
"linear"
,
"factor"
:
factor
}
inv_freq
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
self
.
assertEqual
(
attention_scale
,
1.0
)
# attention scale is always 1 for linear RoPE
torch
.
testing
.
assert_close
(
inv_freq
,
default_inv_freq
/
factor
)
def
test_dynamic_rope_numerically
(
self
):
# fmt: off
EXPECTED_INV_FREQ
=
torch
.
tensor
(
[
1.0000e+00
,
8.0931e-01
,
6.5498e-01
,
5.3008e-01
,
4.2900e-01
,
3.4720e-01
,
2.8099e-01
,
2.2741e-01
,
1.8404e-01
,
1.4895e-01
,
1.2055e-01
,
9.7558e-02
,
7.8955e-02
,
6.3899e-02
,
5.1714e-02
,
4.1853e-02
,
3.3872e-02
,
2.7413e-02
,
2.2185e-02
,
1.7955e-02
,
1.4531e-02
,
1.1760e-02
,
9.5176e-03
,
7.7027e-03
,
6.2339e-03
,
5.0451e-03
,
4.0831e-03
,
3.3045e-03
,
2.6744e-03
,
2.1644e-03
,
1.7517e-03
,
1.4176e-03
,
1.1473e-03
,
9.2852e-04
,
7.5146e-04
,
6.0817e-04
,
4.9220e-04
,
3.9834e-04
,
3.2238e-04
,
2.6091e-04
,
2.1115e-04
,
1.7089e-04
,
1.3830e-04
,
1.1193e-04
,
9.0585e-05
,
7.3312e-05
,
5.9332e-05
,
4.8018e-05
,
3.8861e-05
,
3.1451e-05
,
2.5453e-05
,
2.0600e-05
,
1.6672e-05
,
1.3492e-05
,
1.0920e-05
,
8.8374e-06
,
7.1522e-06
,
5.7883e-06
,
4.6845e-06
,
3.7912e-06
,
3.0683e-06
,
2.4832e-06
,
2.0097e-06
,
1.6265e-06
],
device
=
torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config
=
LlamaConfig
()
self
.
assertEqual
(
config
.
rope_scaling
,
None
)
self
.
assertEqual
(
config
.
hidden_size
,
4096
)
self
.
assertEqual
(
config
.
num_attention_heads
,
32
)
self
.
assertEqual
(
config
.
rope_theta
,
10000.0
)
self
.
assertFalse
(
hasattr
(
config
,
"partial_rotary_factor"
))
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"default"
]
default_inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
# Check 1: this is a dynamic scaling strategy, it will not scale unless we provide `seq_len` larger than the
# model's original training sequence length
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"dynamic"
]
for
factor
in
(
2.0
,
10.0
,
20.0
):
config
.
rope_scaling
=
{
"rope_type"
:
"dynamic"
,
"factor"
:
factor
}
inv_freq
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
self
.
assertEqual
(
attention_scale
,
1.0
)
# attention scale is always 1 for dynamic RoPE
torch
.
testing
.
assert_close
(
inv_freq
,
default_inv_freq
)
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
,
seq_len
=
1
)
torch
.
testing
.
assert_close
(
inv_freq
,
default_inv_freq
)
# Check 2: if we provide `seq_len` larger than the model's original training sequence length, the frequencies
# will scale up (i.e., the inverse frequencies will scale down).
factor
=
10.0
config
.
rope_scaling
=
{
"rope_type"
:
"dynamic"
,
"factor"
:
factor
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
,
seq_len
=
16384
)
with
self
.
assertRaises
(
AssertionError
):
# It is NOT a linear factor
torch
.
testing
.
assert_close
(
inv_freq
,
default_inv_freq
/
factor
)
torch
.
testing
.
assert_close
(
inv_freq
,
EXPECTED_INV_FREQ
)
def
test_yarn_rope_numerically
(
self
):
# fmt: off
EXPECTED_INV_FREQ
=
torch
.
tensor
(
[
1.0000e+00
,
8.6596e-01
,
7.4989e-01
,
6.4938e-01
,
5.6234e-01
,
4.8697e-01
,
4.2170e-01
,
3.6517e-01
,
3.1623e-01
,
2.7384e-01
,
2.3714e-01
,
2.0535e-01
,
1.7783e-01
,
1.5399e-01
,
1.3335e-01
,
1.1548e-01
,
1.0000e-01
,
8.3479e-02
,
6.9590e-02
,
5.7925e-02
,
4.8136e-02
,
3.9931e-02
,
3.3061e-02
,
2.7315e-02
,
2.2515e-02
,
1.8512e-02
,
1.5177e-02
,
1.2403e-02
,
1.0101e-02
,
8.1924e-03
,
6.6143e-03
,
5.3120e-03
,
4.2400e-03
,
3.3599e-03
,
2.6396e-03
,
2.0520e-03
,
1.5746e-03
,
1.1882e-03
,
8.7713e-04
,
6.2810e-04
,
4.3007e-04
,
2.7384e-04
,
2.3714e-04
,
2.0535e-04
,
1.7783e-04
,
1.5399e-04
,
1.3335e-04
,
1.1548e-04
,
1.0000e-04
,
8.6596e-05
,
7.4989e-05
,
6.4938e-05
,
5.6234e-05
,
4.8697e-05
,
4.2170e-05
,
3.6517e-05
,
3.1623e-05
,
2.7384e-05
,
2.3714e-05
,
2.0535e-05
,
1.7783e-05
,
1.5399e-05
,
1.3335e-05
,
1.1548e-05
],
device
=
torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config
=
LlamaConfig
()
self
.
assertEqual
(
config
.
rope_scaling
,
None
)
self
.
assertEqual
(
config
.
hidden_size
,
4096
)
self
.
assertEqual
(
config
.
num_attention_heads
,
32
)
self
.
assertEqual
(
config
.
rope_theta
,
10000.0
)
self
.
assertFalse
(
hasattr
(
config
,
"partial_rotary_factor"
))
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"default"
]
default_inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
# `0.1 * math.log(factor) + 1.0`
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"yarn"
]
for
factor
in
(
2.0
,
10.0
,
20.0
):
config
.
rope_scaling
=
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
}
_
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
self
.
assertEqual
(
attention_scale
,
0.1
*
math
.
log
(
factor
)
+
1.0
)
config
.
rope_scaling
=
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"attention_factor"
:
0.5
}
_
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
,
seq_len
=
1
)
self
.
assertEqual
(
attention_scale
,
0.5
)
# Check 2: based on `beta_fast` and `beta_slow`, the frequencies will be scaled between 1 and `factor`.
# Increasing `beta_fast` will make RoPE more interpolative (apply scaling), and the other way around.
# `beta_slow` behaves the opposite way. Remember: `beta_fast` > `beta_slow`
# (note: adds a margin to the test for numerical stability)
factor
=
10.0
margin
=
1e-8
config
.
rope_scaling
=
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"beta_fast"
:
32
,
"beta_slow"
:
1
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
is_bounded_by_factor
=
[
((
default_inv_freq
[
idx
]
/
factor
)
-
margin
)
<=
yarn_inv_freq_value
<=
(
default_inv_freq
[
idx
]
+
margin
)
for
idx
,
yarn_inv_freq_value
in
enumerate
(
inv_freq
)
]
self
.
assertTrue
(
all
(
is_bounded_by_factor
))
# super high beta_fast = interpolation (i.e. scaling) in all but the first inverse frequency. The last ~20
# values (empirically checked for `beta_fast` = 1000) should be very small to linear scaling
config
.
rope_scaling
=
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"beta_fast"
:
1000
,
"beta_slow"
:
1
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
is_interpolating
=
[
yarn_inv_freq_value
<
(
default_inv_freq
[
idx
]
+
margin
)
for
idx
,
yarn_inv_freq_value
in
enumerate
(
inv_freq
)
]
self
.
assertFalse
(
is_interpolating
[
0
])
self
.
assertTrue
(
all
(
is_interpolating
[
1
:]))
torch
.
testing
.
assert_close
(
inv_freq
[
-
20
:],
default_inv_freq
[
-
20
:]
/
factor
)
# Check 3: numerical snapshot to avoid regressions
config
.
rope_scaling
=
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"beta_fast"
:
32
,
"beta_slow"
:
1
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
torch
.
testing
.
assert_close
(
inv_freq
,
EXPECTED_INV_FREQ
)
def
test_longrope_rope_numerically
(
self
):
# input sanity checks: if these change, the output will also change
config
=
LlamaConfig
()
self
.
assertEqual
(
config
.
rope_scaling
,
None
)
self
.
assertEqual
(
config
.
hidden_size
,
4096
)
self
.
assertEqual
(
config
.
num_attention_heads
,
32
)
self
.
assertEqual
(
config
.
rope_theta
,
10000.0
)
self
.
assertFalse
(
hasattr
(
config
,
"partial_rotary_factor"
))
# longrope applies scaling on EACH inv frequency, `short_factor` or `long_factor`, depending on `factor`
dim
=
config
.
hidden_size
//
config
.
num_attention_heads
short_factor
=
[
2.0
]
*
(
dim
//
2
)
# scaling applied when factor == 1.0
long_factor
=
torch
.
ones
(
dim
//
2
).
cumsum
(
0
).
tolist
()
# scaling applied when factor > 1.0
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"default"
]
default_inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
# Check 1: according to the paper, if `attention_factor` is not specified, then it has a specific default --
# `math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))`
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"longrope"
]
max_position_embeddings
=
config
.
max_position_embeddings
for
factor
in
(
2.0
,
10.0
,
20.0
):
config
.
rope_scaling
=
{
"rope_type"
:
"longrope"
,
"factor"
:
factor
,
"short_factor"
:
short_factor
,
"long_factor"
:
long_factor
,
}
_
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
self
.
assertEqual
(
attention_scale
,
math
.
sqrt
(
1
+
math
.
log
(
factor
)
/
math
.
log
(
max_position_embeddings
)))
config
.
rope_scaling
=
{
"rope_type"
:
"longrope"
,
"factor"
:
factor
,
"short_factor"
:
short_factor
,
"long_factor"
:
long_factor
,
"attention_factor"
:
0.5
,
}
_
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
,
seq_len
=
1
)
self
.
assertEqual
(
attention_scale
,
0.5
)
# Check 2: Factor == 1.0 -> short factor is applied to the default frequencies
factor
=
1.0
config
.
rope_scaling
=
{
"rope_type"
:
"longrope"
,
"factor"
:
factor
,
"short_factor"
:
short_factor
,
"long_factor"
:
long_factor
,
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
torch
.
testing
.
assert_close
(
inv_freq
,
default_inv_freq
/
torch
.
tensor
(
short_factor
).
to
(
torch_device
))
# Check 3: Factor > 1.0 -> long factor is applied to the default frequencies
factor
=
10.0
config
.
rope_scaling
=
{
"rope_type"
:
"longrope"
,
"factor"
:
factor
,
"short_factor"
:
short_factor
,
"long_factor"
:
long_factor
,
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
torch
.
testing
.
assert_close
(
inv_freq
,
default_inv_freq
/
torch
.
tensor
(
long_factor
).
to
(
torch_device
))
def
test_llama3_rope_numerically
(
self
):
# fmt: off
EXPECTED_INV_FREQ
=
torch
.
tensor
(
[
1.0000e+00
,
8.6596e-01
,
7.4989e-01
,
6.4938e-01
,
5.6234e-01
,
4.8697e-01
,
4.2170e-01
,
3.6517e-01
,
3.1623e-01
,
2.7384e-01
,
2.3714e-01
,
2.0535e-01
,
1.7783e-01
,
1.5399e-01
,
1.3335e-01
,
1.1548e-01
,
1.0000e-01
,
8.6596e-02
,
7.4989e-02
,
6.4938e-02
,
5.6234e-02
,
4.8697e-02
,
4.2170e-02
,
3.6517e-02
,
3.1623e-02
,
2.7384e-02
,
2.3714e-02
,
2.0535e-02
,
1.7783e-02
,
1.5399e-02
,
1.3335e-02
,
1.0730e-02
,
7.7785e-03
,
5.6009e-03
,
3.9991e-03
,
2.8248e-03
,
1.9675e-03
,
1.3449e-03
,
8.9549e-04
,
5.7363e-04
,
3.4539e-04
,
2.7384e-04
,
2.3714e-04
,
2.0535e-04
,
1.7783e-04
,
1.5399e-04
,
1.3335e-04
,
1.1548e-04
,
1.0000e-04
,
8.6596e-05
,
7.4989e-05
,
6.4938e-05
,
5.6234e-05
,
4.8697e-05
,
4.2170e-05
,
3.6517e-05
,
3.1623e-05
,
2.7384e-05
,
2.3714e-05
,
2.0535e-05
,
1.7783e-05
,
1.5399e-05
,
1.3335e-05
,
1.1548e-05
],
device
=
torch_device
)
# fmt: on
# input sanity checks: if these change, the output will also change
config
=
LlamaConfig
()
self
.
assertEqual
(
config
.
rope_scaling
,
None
)
self
.
assertEqual
(
config
.
hidden_size
,
4096
)
self
.
assertEqual
(
config
.
num_attention_heads
,
32
)
self
.
assertEqual
(
config
.
rope_theta
,
10000.0
)
self
.
assertFalse
(
hasattr
(
config
,
"partial_rotary_factor"
))
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"default"
]
default_inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
# Check 1: `attention_factor` is always 1
rope_fn
=
ROPE_INIT_FUNCTIONS
[
"llama3"
]
for
factor
in
(
2.0
,
10.0
,
20.0
):
config
.
rope_scaling
=
{
"rope_type"
:
"llama3"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
2048
,
"low_freq_factor"
:
1
,
"high_freq_factor"
:
4
,
}
_
,
attention_scale
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
self
.
assertEqual
(
attention_scale
,
1.0
)
# Check 2: based on `low_freq_factor` and `high_freq_factor`, the frequencies will be scaled between 1 and
# `factor` (similar to yarn). Low frequencies get scaled by `factor`, high frequences see no change, medium
# frequencies are scaled by a value in between. Changing `low_freq_factor` and `high_freq_factor` changes what
# is considered low, medium, and high frequencies.
factor
=
10.0
config
.
rope_scaling
=
{
"rope_type"
:
"llama3"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
2048
,
"low_freq_factor"
:
1
,
"high_freq_factor"
:
4
,
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
is_bounded_by_factor
=
[
(
default_inv_freq
[
idx
]
/
factor
)
<=
llama3_inv_freq_value
<=
default_inv_freq
[
idx
]
for
idx
,
llama3_inv_freq_value
in
enumerate
(
inv_freq
)
]
self
.
assertTrue
(
all
(
is_bounded_by_factor
))
# if we change `high_freq_factor` to a very high value, none is considered high-frequency -> ALL values will be
# scaled
config
.
rope_scaling
=
config
.
rope_scaling
=
{
"rope_type"
:
"llama3"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
2048
,
"low_freq_factor"
:
1
,
"high_freq_factor"
:
1000
,
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
is_scaled
=
[
yarn_inv_freq_value
<
default_inv_freq
[
idx
]
for
idx
,
yarn_inv_freq_value
in
enumerate
(
inv_freq
)]
self
.
assertTrue
(
all
(
is_scaled
))
# Check 3: numerical snapshot to avoid regressions
config
.
rope_scaling
=
{
"rope_type"
:
"llama3"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
2048
,
"low_freq_factor"
:
1
,
"high_freq_factor"
:
4
,
}
inv_freq
,
_
=
rope_fn
(
config
=
config
,
device
=
torch_device
)
torch
.
testing
.
assert_close
(
inv_freq
,
EXPECTED_INV_FREQ
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment