chenpangpang / transformers · Commits · 3994fa5b

🚨 Llama: update rope scaling to match static cache changes (#29143)

Unverified commit 3994fa5b, authored Feb 21, 2024 by Joao Gante, committed by GitHub on Feb 21, 2024. Parent: 1a77f07f.
Showing 7 changed files with 38 additions and 44 deletions.
src/transformers/models/deprecated/open_llama/modeling_open_llama.py  (+2, -2)
src/transformers/models/falcon/modeling_falcon.py  (+4, -2)
src/transformers/models/llama/modeling_llama.py  (+26, -33)
src/transformers/models/persimmon/modeling_persimmon.py  (+2, -2)
src/transformers/models/phi/modeling_phi.py  (+2, -2)
src/transformers/models/stablelm/modeling_stablelm.py  (+2, -2)
tests/models/llama/test_modeling_llama.py  (+0, -1)
src/transformers/models/deprecated/open_llama/modeling_open_llama.py

@@ -100,7 +100,7 @@ class OpenLlamaRotaryEmbedding(nn.Module):
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->OpenLlama
 class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
     """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -120,7 +120,7 @@ class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->OpenLlama
 class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
     """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
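In this file only the "# Copied from" targets change: with the Llama scaling classes rewritten (see the modeling_llama.py diff below), the OpenLlama copies, and likewise Persimmon, Phi, and StableLm further down, are now declared as copies of the Falcon versions, which keep the cached cos/sin implementation. These markers are enforced by the library's copy-consistency check (utils/check_copies.py), which re-derives the target block from the named source plus the stated renames. A rough sketch of that idea, purely illustrative and not the actual check:

# Illustrative sketch only (not transformers' actual utils/check_copies.py logic):
# a "Copied from <source> with X->Y" marker is valid when applying the renames to
# the source block reproduces the target block exactly.
def derived_matches_source(source_code: str, target_code: str, renames: dict) -> bool:
    expected = source_code
    for old, new in renames.items():
        expected = expected.replace(old, new)
    return expected == target_code


source = "class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding): ..."
target = "class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): ..."
print(derived_matches_source(source, target, {"Falcon": "OpenLlama"}))  # True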
src/transformers/models/falcon/modeling_falcon.py

@@ -167,7 +167,8 @@ class FalconRotaryEmbedding(nn.Module):
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
+# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Falcon
+# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
 class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
     """FalconRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -187,7 +188,8 @@ class FalconLinearScalingRotaryEmbedding(FalconRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
+# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Falcon
+# TODO @joao no longer copied from LLama after static cache, fix me (copied -> Copied)
 class FalconDynamicNTKScalingRotaryEmbedding(FalconRotaryEmbedding):
     """FalconRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
src/transformers/models/llama/modeling_llama.py

@@ -94,7 +94,6 @@ ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
 class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
@@ -118,6 +117,9 @@ class LlamaRotaryEmbedding(nn.Module):
         return self._cos_cached
 
     def forward(self, x, position_ids, seq_len=None):
+        if seq_len is not None:
+            logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.40.")
+
         # x: [bs, num_attention_heads, seq_len, head_size]
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
@@ -138,16 +140,11 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
 
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: a scaling factor is aplied to the position ids
+        position_ids = position_ids.float() / self.scaling_factor
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin
 
 
 class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
@@ -157,23 +154,20 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
         self.scaling_factor = scaling_factor
         super().__init__(dim, max_position_embeddings, base, device)
 
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
+    def forward(self, x, position_ids, seq_len=None):
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_position_embeddings:
             base = self.base * (
                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
             ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+            inv_freq = 1.0 / (
+                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation
+
+        cos, sin = super().forward(x, position_ids, seq_len)
+        return cos, sin
 
 
 def rotate_half(x):
@@ -183,7 +177,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -191,9 +185,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -360,8 +353,8 @@ class LlamaAttention(nn.Module):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -447,8 +440,8 @@ class LlamaFlashAttention2(LlamaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
@@ -645,8 +638,8 @@ class LlamaSdpaAttention(LlamaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
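This file carries the substantive change: LlamaLinearScalingRotaryEmbedding and LlamaDynamicNTKScalingRotaryEmbedding no longer pre-compute cos_cached/sin_cached via _set_cos_sin_cache. Each now overrides forward, adjusts the positions (linear) or recomputes inv_freq from the observed sequence length (dynamic NTK), and delegates to LlamaRotaryEmbedding.forward, which derives cos/sin directly from position_ids; the attention classes in turn drop the seq_len and trailing position_ids arguments they no longer need. A minimal, self-contained sketch of that composition pattern, using toy classes rather than the transformers ones:

# Minimal sketch of the pattern the diff moves to; TinyRotaryEmbedding and
# TinyLinearScalingRotaryEmbedding are toy stand-ins, not the transformers classes.
import torch


class TinyRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, position_ids):
        # position_ids: [batch, seq_len] -> cos, sin: [batch, seq_len, dim]
        freqs = position_ids.float()[..., None] * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


class TinyLinearScalingRotaryEmbedding(TinyRotaryEmbedding):
    def __init__(self, dim, base=10000.0, scaling_factor=1.0):
        super().__init__(dim, base)
        self.scaling_factor = scaling_factor

    def forward(self, position_ids):
        # Same idea as the new LlamaLinearScalingRotaryEmbedding.forward:
        # rescale the positions, then reuse the parent's computation.
        return super().forward(position_ids.float() / self.scaling_factor)


rope = TinyLinearScalingRotaryEmbedding(dim=8, scaling_factor=2.0)
cos, sin = rope(torch.arange(6)[None, :])
print(cos.shape)  # torch.Size([1, 6, 8])

The dynamic NTK variant follows the same shape: it recomputes inv_freq from torch.max(position_ids) + 1 before calling the parent, exactly as in the hunk above.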
src/transformers/models/persimmon/modeling_persimmon.py

@@ -77,7 +77,7 @@ class PersimmonRotaryEmbedding(nn.Module):
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Persimmon
 class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
     """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -97,7 +97,7 @@ class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Persimmon
 class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
     """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
src/transformers/models/phi/modeling_phi.py

@@ -120,7 +120,7 @@ class PhiRotaryEmbedding(nn.Module):
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi
 class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
     """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -140,7 +140,7 @@ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi
 class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
     """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
src/transformers/models/stablelm/modeling_stablelm.py

@@ -103,7 +103,7 @@ class StableLmRotaryEmbedding(nn.Module):
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->StableLm
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
 class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
     """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -123,7 +123,7 @@ class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->StableLm
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
 class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
     """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
tests/models/llama/test_modeling_llama.py

@@ -362,7 +362,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         pass
 
     @parameterized.expand([("linear",), ("dynamic",)])
-    @unittest.skip("TODO @gante fix this for Llama")
     def test_model_rope_scaling(self, scaling_type):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        short_input = ids_tensor([1, 10], config.vocab_size)
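Dropping the skip re-enables test_model_rope_scaling for both the "linear" and "dynamic" scaling types. For context, a hedged sketch of how these code paths are reached through the config (sizes are arbitrary; the rope_scaling schema shown is the {"type", "factor"} form this version of the library reads):

# Hedged usage sketch; rope_scaling keys ("type", "factor") follow the config schema
# read by the classes in this commit. The tiny model sizes are arbitrary.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    rope_scaling={"type": "linear", "factor": 2.0},  # or {"type": "dynamic", "factor": 2.0}
)
model = LlamaForCausalLM(config)  # builds LlamaLinearScalingRotaryEmbedding internally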