chenpangpang / transformers · Commits

Commit 1749841a (unverified), authored Jun 03, 2024 by Arthur, committed by GitHub on Jun 03, 2024
Parent: 39b2ff69

[`GemmaModel`] fix small typo (#31202)

* fixes
* fix-copies
Showing 4 changed files with 18 additions and 14 deletions:
src/transformers/models/gemma/modeling_gemma.py      +4  -4
src/transformers/models/llama/modeling_llama.py      +3  -3
src/transformers/models/mistral/modeling_mistral.py  +1  -2
utils/diff_model_converter.py                        +10 -5
src/transformers/models/gemma/modeling_gemma.py

@@ -408,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

         if not output_attentions:

@@ -594,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)

@@ -866,9 +866,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        return_legacy_cache = False
+        return_legacy_cache = False  # noqa: F841
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
-            return_legacy_cache = True
+            return_legacy_cache = True  # noqa: F841
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)

         if cache_position is None:
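The attention-class changes replace the hard-coded self.hidden_size with -1 when flattening the per-head attention output: the tensor handed to o_proj has num_heads * head_dim features, which is not guaranteed to equal hidden_size. A minimal sketch, not from the commit and using a hypothetical config, of why the inferred dimension is more robust:

import torch

# Hypothetical config where num_heads * head_dim != hidden_size.
bsz, q_len = 2, 5
num_heads, head_dim, hidden_size = 8, 32, 192

# Attention output before the final projection: (bsz, q_len, num_heads, head_dim).
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# New form: infer the last dimension -> (2, 5, 256).
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()

# Old form would fail for this config:
# attn_output.reshape(bsz, q_len, hidden_size)  # RuntimeError: invalid shape for input size

# o_proj maps num_heads * head_dim back to hidden_size.
o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)
print(o_proj(attn_output).shape)  # torch.Size([2, 5, 192])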
src/transformers/models/llama/modeling_llama.py

@@ -360,7 +360,7 @@ class LlamaAttention(nn.Module):
         attn_output = attn_output.transpose(1, 2).contiguous()

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)

         if self.config.pretraining_tp > 1:
             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)

@@ -467,7 +467,7 @@ class LlamaFlashAttention2(LlamaAttention):
             query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )

-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
         attn_output = self.o_proj(attn_output)

         if not output_attentions:

@@ -653,7 +653,7 @@ class LlamaSdpaAttention(LlamaAttention):
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)
src/transformers/models/mistral/modeling_mistral.py

@@ -620,7 +620,6 @@ class MistralSdpaAttention(MistralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         cos, sin = self.rotary_emb(value_states, position_ids)
-
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:

@@ -656,7 +655,7 @@ class MistralSdpaAttention(MistralAttention):
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, -1)

         attn_output = self.o_proj(attn_output)
utils/diff_model_converter.py

@@ -238,11 +238,16 @@ class SuperTransformer(cst.CSTTransformer):
         Helper method to update the body by removing duplicates before adding new statements.
         """
         deduplicated_new_body = []
-        existing_nodes = {
-            self.python_module.code_for_node(node).strip() for node in new_statements if isinstance(node, cst.CSTNode)
-        }
+        existing_nodes = set()
+        for node in new_statements:
+            code = self.python_module.code_for_node(node)
+            comment_less_code = re.sub(r"#.*", "", code).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            existing_nodes.add(comment_less_code)
         for stmt in existing_body:
-            if self.python_module.code_for_node(stmt).strip() not in existing_nodes:
+            comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip()
+            comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip()
+            if comment_less_code not in existing_nodes:
                 if m.matches(stmt, DOCSTRING_NODE) and self.has_docstring:
                     continue
                 deduplicated_new_body.append(stmt)
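The rewritten deduplication no longer compares raw statement text: each statement is normalized by stripping comments and trailing spaces before it is added to, or looked up in, existing_nodes. A standalone sketch of that normalization (the normalize helper is hypothetical, not part of the converter):

import re


def normalize(code: str) -> str:
    # Drop everything after a '#' on each line, then strip trailing spaces before newlines.
    code = re.sub(r"#.*", "", code).strip()
    code = re.sub(r"\ *\n", "\n", code).strip()
    return code


existing_nodes = set()
for snippet in ["x = 1  # noqa: F841", "y = 2"]:
    existing_nodes.add(normalize(snippet))

# A statement that differs only by a comment now counts as a duplicate.
print(normalize("x = 1") in existing_nodes)  # True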
@@ -542,7 +547,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--files_to_parse",
-        default=["/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py"],
+        default=["all"],
         nargs="+",
         help="A list of `diff_xxxx` files that should be converted to single model file",
     )
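Changing the default from a path on the original author's machine to ["all"] means the converter no longer needs an explicit --files_to_parse to run outside that environment (how "all" is then interpreted is defined outside this hunk). A small sketch, relying only on standard argparse behaviour, of how the new default is picked up:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files_to_parse",
    default=["all"],
    nargs="+",
    help="A list of `diff_xxxx` files that should be converted to single model file",
)

# No flag given: the new default applies.
print(parser.parse_args([]).files_to_parse)  # ['all']

# Explicit files still override the default.
print(parser.parse_args(["--files_to_parse", "diff_my_new_model.py"]).files_to_parse)  # ['diff_my_new_model.py']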