Commit 5413b898 (Unverified)
Authored May 09, 2024 by Raushan Turganbay; committed by GitHub on May 09, 2024

KV cache is no longer a model attribute (#30730)

kv_cache is no longer a model attribute

Parent: 218f4413
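The change is identical in every file below: the line `past_key_value = getattr(self, "past_key_value", past_key_value)` is deleted from each attention `forward`, together with its explanatory comment in the SDPA classes. A minimal sketch of what that fallback did and why it can go; the class names are illustrative, only the `getattr` line is from the diff:

# Illustrative sketch; class names are hypothetical, only the getattr line is from the diff.
class AttentionBefore:
    def forward(self, hidden_states, past_key_value=None):
        # If a static cache had been installed on the module as an instance
        # attribute named "past_key_value", it silently overrode the cache
        # supplied by the caller:
        past_key_value = getattr(self, "past_key_value", past_key_value)
        return hidden_states, past_key_value


class AttentionAfter:
    def forward(self, hidden_states, past_key_value=None):
        # Per the commit title, the KV cache is no longer stored on the model,
        # so getattr would always have returned its third argument; the cache
        # now comes exclusively from the caller.
        return hidden_states, past_key_value

Since no module sets self.past_key_value any more, the deleted line was dead code: the getattr always fell through to the past_key_value argument.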
Showing 5 changed files with 0 additions and 28 deletions
src/transformers/models/cohere/modeling_cohere.py   +0 -6
src/transformers/models/dbrx/modeling_dbrx.py       +0 -5
src/transformers/models/gemma/modeling_gemma.py     +0 -5
src/transformers/models/llama/modeling_llama.py     +0 -6
src/transformers/models/olmo/modeling_olmo.py       +0 -6
src/transformers/models/cohere/modeling_cohere.py

@@ -271,7 +271,6 @@ class CohereAttention(nn.Module):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -365,8 +364,6 @@ class CohereFlashAttention2(CohereAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

@@ -571,9 +568,6 @@ class CohereSdpaAttention(CohereAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
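For orientation, the retained code just below each hunk consumes cache_kwargs by handing the new key/value states to the cache object that arrived as a forward argument. A hedged sketch of that unchanged surrounding code, based on the Cache API in transformers around this commit:

# Not part of this diff; a sketch of the unchanged code that follows each hunk.
if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    # The cache object updates itself with this layer's new keys/values and
    # returns the full cached states to attend over:
    key_states, value_states = past_key_value.update(
        key_states, value_states, self.layer_idx, cache_kwargs
    )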
src/transformers/models/dbrx/modeling_dbrx.py

@@ -287,7 +287,6 @@ class DbrxAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -387,8 +386,6 @@ class DbrxFlashAttention2(DbrxAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

@@ -600,8 +597,6 @@ class DbrxSdpaAttention(DbrxAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
src/transformers/models/gemma/modeling_gemma.py

@@ -262,7 +262,6 @@ class GemmaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

@@ -353,8 +352,6 @@ class GemmaFlashAttention2(GemmaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

@@ -552,8 +549,6 @@ class GemmaSdpaAttention(GemmaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
src/transformers/models/llama/modeling_llama.py

@@ -356,7 +356,6 @@ class LlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -452,8 +451,6 @@ class LlamaFlashAttention2(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

@@ -650,9 +647,6 @@ class LlamaSdpaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
src/transformers/models/olmo/modeling_olmo.py

@@ -328,7 +328,6 @@ class OlmoAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

@@ -419,8 +418,6 @@ class OlmoFlashAttention2(OlmoAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}

@@ -624,9 +621,6 @@ class OlmoSdpaAttention(OlmoAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
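With the attribute fallback removed, a cache reaches these attention modules only through the normal forward/generate arguments. A minimal usage sketch, assuming a transformers version from around this commit; the checkpoint name and generation settings are illustrative, not from the diff:

# Illustrative usage; checkpoint and settings are assumptions, not from this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("The KV cache", return_tensors="pt")

# The static cache is selected via generation config rather than by installing
# a cache object on the model as an instance attribute:
out = model.generate(**inputs, cache_implementation="static", max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))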