Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b75ad566
Unverified
Commit
b75ad566
authored
Jul 31, 2024
by
Joao Gante
Committed by
GitHub
Jul 31, 2024
Browse files
Llama 3.1: Fix incorrect `inv_freq` assignment (#32330)
fix
💩
parent
7f552e28
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
5 deletions
+33
-5
src/transformers/modeling_rope_utils.py
src/transformers/modeling_rope_utils.py
+4
-4
tests/models/llama/test_modeling_llama.py
tests/models/llama/test_modeling_llama.py
+29
-1
No files found.
src/transformers/modeling_rope_utils.py
View file @
b75ad566
...
@@ -328,14 +328,14 @@ def _compute_llama3_parameters(
...
@@ -328,14 +328,14 @@ def _compute_llama3_parameters(
wavelen
=
2
*
math
.
pi
/
inv_freq
wavelen
=
2
*
math
.
pi
/
inv_freq
# wavelen < high_freq_wavelen: do nothing
# wavelen < high_freq_wavelen: do nothing
# wavelen > low_freq_wavelen: divide by factor
# wavelen > low_freq_wavelen: divide by factor
inv_freq_
new
=
torch
.
where
(
wavelen
>
low_freq_wavelen
,
inv_freq
/
factor
,
inv_freq
)
inv_freq_
llama
=
torch
.
where
(
wavelen
>
low_freq_wavelen
,
inv_freq
/
factor
,
inv_freq
)
# otherwise: interpolate between the two, using a smooth factor
# otherwise: interpolate between the two, using a smooth factor
smooth_factor
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
smooth_factor
=
(
old_context_len
/
wavelen
-
low_freq_factor
)
/
(
high_freq_factor
-
low_freq_factor
)
smoothed_inv_freq
=
(
1
-
smooth_factor
)
*
inv_freq_
new
/
factor
+
smooth_factor
*
inv_freq_
new
smoothed_inv_freq
=
(
1
-
smooth_factor
)
*
inv_freq_
llama
/
factor
+
smooth_factor
*
inv_freq_
llama
is_medium_freq
=
~
(
wavelen
<
high_freq_wavelen
)
*
~
(
wavelen
>
low_freq_wavelen
)
is_medium_freq
=
~
(
wavelen
<
high_freq_wavelen
)
*
~
(
wavelen
>
low_freq_wavelen
)
inv_freq_
new
=
torch
.
where
(
is_medium_freq
,
smoothed_inv_freq
,
inv_freq_
new
)
inv_freq_
llama
=
torch
.
where
(
is_medium_freq
,
smoothed_inv_freq
,
inv_freq_
llama
)
return
inv_freq
,
attention_factor
return
inv_freq
_llama
,
attention_factor
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
...
...
tests/models/llama/test_modeling_llama.py
View file @
b75ad566
...
@@ -22,7 +22,7 @@ import pytest
...
@@ -22,7 +22,7 @@ import pytest
from
packaging
import
version
from
packaging
import
version
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
transformers
import
LlamaConfig
,
StaticCache
,
is_torch_available
,
set_seed
from
transformers
import
AutoTokenizer
,
LlamaConfig
,
StaticCache
,
is_torch_available
,
set_seed
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
require_bitsandbytes
,
require_bitsandbytes
,
require_flash_attn
,
require_flash_attn
,
...
@@ -718,6 +718,34 @@ class LlamaIntegrationTest(unittest.TestCase):
...
@@ -718,6 +718,34 @@ class LlamaIntegrationTest(unittest.TestCase):
# 8 is for A100 / A10 and 7 for T4
# 8 is for A100 / A10 and 7 for T4
cls
.
cuda_compute_capability_major_version
=
torch
.
cuda
.
get_device_capability
()[
0
]
cls
.
cuda_compute_capability_major_version
=
torch
.
cuda
.
get_device_capability
()[
0
]
@
slow
@
require_read_token
def
test_llama_3_1_hard
(
self
):
"""
An integration test for llama 3.1. It tests against a long output to ensure the subtle numerical differences
from llama 3.1.'s RoPE can be detected
"""
EXPECTED_TEXT
=
(
"Tell me about the french revolution. The french revolution was a period of radical social and political "
"upheaval in France that lasted from 1789 until 1799. It was a time of great change and upheaval, marked "
"by the overthrow of the monarchy, the rise of the middle class, and the eventual establishment of the "
"First French Republic.
\n
The revolution began in 1789 with the Estates-General, a representative "
"assembly that had not met since 1614. The Third Estate, which represented the common people, "
"demanded greater representation and eventually broke away to form the National Assembly. This marked "
"the beginning of the end of the absolute monarchy and the rise of the middle class.
\n
"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"meta-llama/Meta-Llama-3.1-8B-Instruct"
)
model
=
LlamaForCausalLM
.
from_pretrained
(
"meta-llama/Meta-Llama-3.1-8B-Instruct"
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
bfloat16
)
input_text
=
[
"Tell me about the french revolution."
]
model_inputs
=
tokenizer
(
input_text
,
return_tensors
=
"pt"
).
to
(
model
.
device
)
generated_ids
=
model
.
generate
(
**
model_inputs
,
max_new_tokens
=
128
,
do_sample
=
False
)
generated_text
=
tokenizer
.
decode
(
generated_ids
[
0
],
skip_special_tokens
=
True
)
self
.
assertEqual
(
generated_text
,
EXPECTED_TEXT
)
@
slow
@
slow
@
require_read_token
@
require_read_token
def
test_model_7b_logits_bf16
(
self
):
def
test_model_7b_logits_bf16
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment