Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
53e2fd78
Unverified
Commit
53e2fd78
authored
Sep 01, 2023
by
Joao Gante
Committed by
GitHub
Sep 01, 2023
Browse files
Falcon: Add RoPE scaling (#25878)
parent
024acd27
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
193 additions
and
23 deletions
+193
-23
src/transformers/models/deprecated/open_llama/configuration_open_llama.py
.../models/deprecated/open_llama/configuration_open_llama.py
+2
-2
src/transformers/models/falcon/configuration_falcon.py
src/transformers/models/falcon/configuration_falcon.py
+44
-0
src/transformers/models/falcon/modeling_falcon.py
src/transformers/models/falcon/modeling_falcon.py
+109
-16
src/transformers/models/gpt_neox/configuration_gpt_neox.py
src/transformers/models/gpt_neox/configuration_gpt_neox.py
+2
-2
src/transformers/models/llama/configuration_llama.py
src/transformers/models/llama/configuration_llama.py
+2
-2
tests/models/falcon/test_modeling_falcon.py
tests/models/falcon/test_modeling_falcon.py
+34
-1
No files found.
src/transformers/models/deprecated/open_llama/configuration_open_llama.py
View file @
53e2fd78
...
...
@@ -154,14 +154,14 @@ class OpenLlamaConfig(PretrainedConfig):
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with with two fields, `
nam
e` and `factor`, "
"`rope_scaling` must be a dictionary with with two fields, `
typ
e` and `factor`, "
f
"got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
f
"`rope_scaling`'s
nam
e field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
f
"`rope_scaling`'s
typ
e field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
f
"`rope_scaling`'s factor field must be an float > 1, got
{
rope_scaling_factor
}
"
)
src/transformers/models/falcon/configuration_falcon.py
View file @
53e2fd78
...
...
@@ -72,6 +72,19 @@ class FalconConfig(PretrainedConfig):
instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
bias (`bool`, *optional*, defaults to `False`):
Whether to use bias on Linear layers.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with, when `alibi` is `False`. Pretrained
Falcon models with RoPE support up to 2048 tokens.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
bos_token_id (`int`, *optional*, defaults to 11):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 11):
...
...
@@ -111,6 +124,9 @@ class FalconConfig(PretrainedConfig):
multi_query
=
True
,
parallel_attn
=
True
,
bias
=
False
,
max_position_embeddings
=
2048
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
bos_token_id
=
11
,
eos_token_id
=
11
,
**
kwargs
,
...
...
@@ -135,6 +151,10 @@ class FalconConfig(PretrainedConfig):
self
.
multi_query
=
multi_query
# Ignored when new_decoder_architecture is True
self
.
parallel_attn
=
parallel_attn
self
.
bias
=
bias
self
.
max_position_embeddings
=
max_position_embeddings
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
_rope_scaling_validation
()
super
().
__init__
(
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
**
kwargs
)
...
...
@@ -145,3 +165,27 @@ class FalconConfig(PretrainedConfig):
@
property
def
rotary
(
self
):
return
not
self
.
alibi
def
_rope_scaling_validation
(
self
):
"""
Validate the `rope_scaling` configuration.
"""
if
self
.
rope_scaling
is
None
:
return
if
self
.
rotary
:
raise
ValueError
(
"`rope_scaling` is not supported when `alibi` is `True`."
)
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f
"got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
f
"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
f
"`rope_scaling`'s factor field must be an float > 1, got
{
rope_scaling_factor
}
"
)
src/transformers/models/falcon/modeling_falcon.py
View file @
53e2fd78
...
...
@@ -71,20 +71,20 @@ class FalconRotaryEmbedding(nn.Module):
n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
"""
def
__init__
(
self
,
head_dim
:
int
,
base
=
10000
):
def
__init__
(
self
,
head_dim
:
int
,
base
=
10000
,
max_position_embeddings
=
2048
):
super
().
__init__
()
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
head_dim
,
2
).
float
()
/
head_dim
))
self
.
base
=
base
self
.
max_position_embeddings
=
max_position_embeddings
inv_freq
=
1.0
/
(
self
.
base
**
(
torch
.
arange
(
0
,
head_dim
,
2
).
float
()
/
head_dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
self
.
head_dim
=
head_dim
self
.
seq_len_cached
=
-
1
self
.
cos_cached
:
torch
.
Tensor
|
None
=
None
self
.
sin_cached
:
torch
.
Tensor
|
None
=
None
def
cos_sin
(
self
,
seq_len
:
int
,
past_key_values_length
:
int
,
device
=
"cpu"
,
dtype
=
torch
.
bfloat16
)
->
torch
.
Tensor
:
total_length
=
seq_len
+
past_key_values_length
if
total_length
>
self
.
seq_len_cached
:
self
.
seq_len_cached
=
total_length
t
=
torch
.
arange
(
total_length
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
self
.
seq_len_cached
=
seq_len
t
=
torch
.
arange
(
seq_len
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
...
...
@@ -97,6 +97,10 @@ class FalconRotaryEmbedding(nn.Module):
self
.
cos_cached
=
self
.
cos_cached
.
type
(
dtype
)
self
.
sin_cached
=
self
.
sin_cached
.
type
(
dtype
)
def
cos_sin
(
self
,
seq_len
:
int
,
past_key_values_length
:
int
,
device
=
"cpu"
,
dtype
=
torch
.
bfloat16
)
->
torch
.
Tensor
:
total_length
=
seq_len
+
past_key_values_length
if
total_length
>
self
.
seq_len_cached
:
self
.
_set_cos_sin_cache
(
total_length
,
device
,
dtype
)
return
(
self
.
cos_cached
[:,
past_key_values_length
:
seq_len
+
past_key_values_length
],
self
.
sin_cached
[:,
past_key_values_length
:
seq_len
+
past_key_values_length
],
...
...
@@ -108,6 +112,66 @@ class FalconRotaryEmbedding(nn.Module):
return
(
query
*
cos
)
+
(
rotate_half
(
query
)
*
sin
),
(
key
*
cos
)
+
(
rotate_half
(
key
)
*
sin
)
class
FalconLinearScalingRotaryEmbedding
(
FalconRotaryEmbedding
):
"""FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def
__init__
(
self
,
head_dim
:
int
,
base
=
10000
,
max_position_embeddings
=
2048
,
scaling_factor
=
1.0
):
self
.
scaling_factor
=
scaling_factor
super
().
__init__
(
head_dim
,
base
,
max_position_embeddings
)
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
self
.
seq_len_cached
=
seq_len
t
=
torch
.
arange
(
seq_len
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
# This line is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
t
=
t
/
self
.
scaling_factor
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
if
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]:
emb
=
emb
.
float
()
self
.
cos_cached
=
emb
.
cos
()[
None
,
:,
:]
self
.
sin_cached
=
emb
.
sin
()[
None
,
:,
:]
self
.
cos_cached
=
self
.
cos_cached
.
type
(
dtype
)
self
.
sin_cached
=
self
.
sin_cached
.
type
(
dtype
)
class
FalconDynamicNTKScalingRotaryEmbedding
(
FalconRotaryEmbedding
):
"""
FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def
__init__
(
self
,
head_dim
:
int
,
base
=
10000
,
max_position_embeddings
=
2048
,
scaling_factor
=
1.0
):
self
.
scaling_factor
=
scaling_factor
super
().
__init__
(
head_dim
,
base
,
max_position_embeddings
)
def
_set_cos_sin_cache
(
self
,
seq_len
,
device
,
dtype
):
self
.
seq_len_cached
=
seq_len
# This if block is the only difference from FalconRotaryEmbedding._set_cos_sin_cache
if
seq_len
>
self
.
max_position_embeddings
:
base
=
self
.
base
*
(
(
self
.
scaling_factor
*
seq_len
/
self
.
max_position_embeddings
)
-
(
self
.
scaling_factor
-
1
)
)
**
(
self
.
head_dim
/
(
self
.
head_dim
-
2
))
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
head_dim
,
2
).
float
().
to
(
device
)
/
self
.
head_dim
))
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
t
=
torch
.
arange
(
seq_len
,
device
=
device
,
dtype
=
self
.
inv_freq
.
dtype
)
freqs
=
torch
.
einsum
(
"i,j->ij"
,
t
,
self
.
inv_freq
)
emb
=
torch
.
cat
((
freqs
,
freqs
),
dim
=-
1
).
to
(
device
)
if
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]:
emb
=
emb
.
float
()
self
.
cos_cached
=
emb
.
cos
()[
None
,
:,
:]
self
.
sin_cached
=
emb
.
sin
()[
None
,
:,
:]
self
.
cos_cached
=
self
.
cos_cached
.
type
(
dtype
)
self
.
sin_cached
=
self
.
sin_cached
.
type
(
dtype
)
def
_make_causal_mask
(
input_ids_shape
:
torch
.
Size
,
device
:
torch
.
device
,
past_key_values_length
:
int
)
->
torch
.
BoolTensor
:
...
...
@@ -191,6 +255,7 @@ class FalconAttention(nn.Module):
def
__init__
(
self
,
config
:
FalconConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
self
.
head_dim
=
self
.
hidden_size
//
self
.
num_heads
...
...
@@ -203,7 +268,7 @@ class FalconAttention(nn.Module):
f
"
{
self
.
num_heads
}
)."
)
self
.
maybe_rotary
=
FalconRotaryEmbedding
(
config
.
head_dim
)
if
config
.
rotary
else
lambda
q
,
k
,
t
:
(
q
,
k
)
self
.
maybe_rotary
=
self
.
_init_rope
(
)
if
config
.
rotary
else
lambda
q
,
k
,
t
:
(
q
,
k
)
# Layer-wise attention scaling
self
.
inv_norm_factor
=
1.0
/
math
.
sqrt
(
self
.
head_dim
)
...
...
@@ -221,6 +286,34 @@ class FalconAttention(nn.Module):
self
.
attention_dropout
=
nn
.
Dropout
(
config
.
attention_dropout
)
self
.
num_kv_heads
=
config
.
num_kv_heads
if
(
self
.
new_decoder_architecture
or
not
self
.
multi_query
)
else
1
def
_init_rope
(
self
):
if
self
.
config
.
rope_scaling
is
None
:
rotary_emb
=
FalconRotaryEmbedding
(
self
.
head_dim
,
base
=
self
.
config
.
rope_theta
,
max_position_embeddings
=
self
.
config
.
max_position_embeddings
,
)
else
:
scaling_type
=
self
.
config
.
rope_scaling
[
"type"
]
scaling_factor
=
self
.
config
.
rope_scaling
[
"factor"
]
if
scaling_type
==
"linear"
:
rotary_emb
=
FalconLinearScalingRotaryEmbedding
(
self
.
head_dim
,
base
=
self
.
config
.
rope_theta
,
max_position_embeddings
=
self
.
config
.
max_position_embeddings
,
scaling_factor
=
scaling_factor
,
)
elif
scaling_type
==
"dynamic"
:
rotary_emb
=
FalconDynamicNTKScalingRotaryEmbedding
(
self
.
head_dim
,
base
=
self
.
config
.
rope_theta
,
max_position_embeddings
=
self
.
config
.
max_position_embeddings
,
scaling_factor
=
scaling_factor
,
)
else
:
raise
ValueError
(
f
"Unknown RoPE scaling type
{
scaling_type
}
"
)
return
rotary_emb
def
_split_heads
(
self
,
fused_qkv
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
...
...
src/transformers/models/gpt_neox/configuration_gpt_neox.py
View file @
53e2fd78
...
...
@@ -163,14 +163,14 @@ class GPTNeoXConfig(PretrainedConfig):
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with with two fields, `
nam
e` and `factor`, "
"`rope_scaling` must be a dictionary with with two fields, `
typ
e` and `factor`, "
f
"got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
f
"`rope_scaling`'s
nam
e field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
f
"`rope_scaling`'s
typ
e field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
f
"`rope_scaling`'s factor field must be an float > 1, got
{
rope_scaling_factor
}
"
)
src/transformers/models/llama/configuration_llama.py
View file @
53e2fd78
...
...
@@ -165,14 +165,14 @@ class LlamaConfig(PretrainedConfig):
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with with two fields, `
nam
e` and `factor`, "
"`rope_scaling` must be a dictionary with with two fields, `
typ
e` and `factor`, "
f
"got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
f
"`rope_scaling`'s
nam
e field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
f
"`rope_scaling`'s
typ
e field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
f
"`rope_scaling`'s factor field must be an float > 1, got
{
rope_scaling_factor
}
"
)
tests/models/falcon/test_modeling_falcon.py
View file @
53e2fd78
...
...
@@ -17,7 +17,9 @@
import
unittest
from
transformers
import
AutoTokenizer
,
FalconConfig
,
is_torch_available
from
parameterized
import
parameterized
from
transformers
import
AutoTokenizer
,
FalconConfig
,
is_torch_available
,
set_seed
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
...generation.test_utils
import
GenerationTesterMixin
...
...
@@ -410,6 +412,37 @@ class FalconModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
past_kv
[
i
][
1
].
shape
,
(
batch_size
,
num_attention_heads
,
seq_length
,
per_head_embed_dim
)
)
@
parameterized
.
expand
([(
"linear"
,),
(
"dynamic"
,)])
def
test_model_rope_scaling
(
self
,
scaling_type
):
config
,
_
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
short_input
=
ids_tensor
([
1
,
10
],
config
.
vocab_size
)
long_input
=
ids_tensor
([
1
,
int
(
config
.
max_position_embeddings
*
1.5
)],
config
.
vocab_size
)
set_seed
(
42
)
# Fixed seed at init time so the two models get the same random weights
original_model
=
FalconModel
(
config
)
original_model
.
to
(
torch_device
)
original_model
.
eval
()
original_short_output
=
original_model
(
short_input
).
last_hidden_state
original_long_output
=
original_model
(
long_input
).
last_hidden_state
set_seed
(
42
)
# Fixed seed at init time so the two models get the same random weights
config
.
rope_scaling
=
{
"type"
:
scaling_type
,
"factor"
:
10.0
}
scaled_model
=
FalconModel
(
config
)
scaled_model
.
to
(
torch_device
)
scaled_model
.
eval
()
scaled_short_output
=
scaled_model
(
short_input
).
last_hidden_state
scaled_long_output
=
scaled_model
(
long_input
).
last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if
scaling_type
==
"dynamic"
:
self
.
assertTrue
(
torch
.
allclose
(
original_short_output
,
scaled_short_output
,
atol
=
1e-5
))
else
:
self
.
assertFalse
(
torch
.
allclose
(
original_short_output
,
scaled_short_output
,
atol
=
1e-5
))
# The output should be different for long inputs
self
.
assertFalse
(
torch
.
allclose
(
original_long_output
,
scaled_long_output
,
atol
=
1e-5
))
@
require_torch
class
FalconLanguageGenerationTest
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment