Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1749841a
Unverified
Commit
1749841a
authored
Jun 03, 2024
by
Arthur
Committed by
GitHub
Jun 03, 2024
Browse files
[`GemmaModel`] fix small typo (#31202)
* fixes * fix-copies
parent
39b2ff69
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
14 deletions
+18
-14
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+4
-4
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+3
-3
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+1
-2
utils/diff_model_converter.py
utils/diff_model_converter.py
+10
-5
No files found.
src/transformers/models/gemma/modeling_gemma.py
View file @
1749841a
...
@@ -408,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
...
@@ -408,7 +408,7 @@ class GemmaFlashAttention2(GemmaAttention):
query_states
,
key_states
,
value_states
,
attention_mask
,
q_len
,
dropout
=
dropout_rate
query_states
,
key_states
,
value_states
,
attention_mask
,
q_len
,
dropout
=
dropout_rate
)
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
hidden_size
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
-
1
).
contiguous
()
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
if
not
output_attentions
:
if
not
output_attentions
:
...
@@ -594,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
...
@@ -594,7 +594,7 @@ class GemmaSdpaAttention(GemmaAttention):
)
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
view
(
bsz
,
q_len
,
self
.
hidden_size
)
attn_output
=
attn_output
.
view
(
bsz
,
q_len
,
-
1
)
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
...
@@ -866,9 +866,9 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -866,9 +866,9 @@ class GemmaModel(GemmaPreTrainedModel):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
return_legacy_cache
=
False
# noqa: F841
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
return_legacy_cache
=
True
# noqa: F841
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
if
cache_position
is
None
:
...
...
src/transformers/models/llama/modeling_llama.py
View file @
1749841a
...
@@ -360,7 +360,7 @@ class LlamaAttention(nn.Module):
...
@@ -360,7 +360,7 @@ class LlamaAttention(nn.Module):
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
hidden_size
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
-
1
)
if
self
.
config
.
pretraining_tp
>
1
:
if
self
.
config
.
pretraining_tp
>
1
:
attn_output
=
attn_output
.
split
(
self
.
hidden_size
//
self
.
config
.
pretraining_tp
,
dim
=
2
)
attn_output
=
attn_output
.
split
(
self
.
hidden_size
//
self
.
config
.
pretraining_tp
,
dim
=
2
)
...
@@ -467,7 +467,7 @@ class LlamaFlashAttention2(LlamaAttention):
...
@@ -467,7 +467,7 @@ class LlamaFlashAttention2(LlamaAttention):
query_states
,
key_states
,
value_states
,
attention_mask
,
q_len
,
dropout
=
dropout_rate
query_states
,
key_states
,
value_states
,
attention_mask
,
q_len
,
dropout
=
dropout_rate
)
)
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
self
.
hidden_size
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
bsz
,
q_len
,
-
1
).
contiguous
()
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
if
not
output_attentions
:
if
not
output_attentions
:
...
@@ -653,7 +653,7 @@ class LlamaSdpaAttention(LlamaAttention):
...
@@ -653,7 +653,7 @@ class LlamaSdpaAttention(LlamaAttention):
)
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
view
(
bsz
,
q_len
,
self
.
hidden_size
)
attn_output
=
attn_output
.
view
(
bsz
,
q_len
,
-
1
)
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
...
...
src/transformers/models/mistral/modeling_mistral.py
View file @
1749841a
...
@@ -620,7 +620,6 @@ class MistralSdpaAttention(MistralAttention):
...
@@ -620,7 +620,6 @@ class MistralSdpaAttention(MistralAttention):
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
bsz
,
q_len
,
self
.
num_key_value_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
position_ids
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
position_ids
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
...
@@ -656,7 +655,7 @@ class MistralSdpaAttention(MistralAttention):
...
@@ -656,7 +655,7 @@ class MistralSdpaAttention(MistralAttention):
)
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
view
(
bsz
,
q_len
,
self
.
hidden_size
)
attn_output
=
attn_output
.
view
(
bsz
,
q_len
,
-
1
)
attn_output
=
self
.
o_proj
(
attn_output
)
attn_output
=
self
.
o_proj
(
attn_output
)
...
...
utils/diff_model_converter.py
View file @
1749841a
...
@@ -238,11 +238,16 @@ class SuperTransformer(cst.CSTTransformer):
...
@@ -238,11 +238,16 @@ class SuperTransformer(cst.CSTTransformer):
Helper method to update the body by removing duplicates before adding new statements.
Helper method to update the body by removing duplicates before adding new statements.
"""
"""
deduplicated_new_body
=
[]
deduplicated_new_body
=
[]
existing_nodes
=
{
existing_nodes
=
set
()
self
.
python_module
.
code_for_node
(
node
).
strip
()
for
node
in
new_statements
if
isinstance
(
node
,
cst
.
CSTNode
)
for
node
in
new_statements
:
}
code
=
self
.
python_module
.
code_for_node
(
node
)
comment_less_code
=
re
.
sub
(
r
"#.*"
,
""
,
code
).
strip
()
comment_less_code
=
re
.
sub
(
r
"\ *\n"
,
"
\n
"
,
comment_less_code
).
strip
()
existing_nodes
.
add
(
comment_less_code
)
for
stmt
in
existing_body
:
for
stmt
in
existing_body
:
if
self
.
python_module
.
code_for_node
(
stmt
).
strip
()
not
in
existing_nodes
:
comment_less_code
=
re
.
sub
(
r
"#.*"
,
""
,
self
.
python_module
.
code_for_node
(
stmt
)).
strip
()
comment_less_code
=
re
.
sub
(
r
"\ *\n"
,
"
\n
"
,
comment_less_code
).
strip
()
if
comment_less_code
not
in
existing_nodes
:
if
m
.
matches
(
stmt
,
DOCSTRING_NODE
)
and
self
.
has_docstring
:
if
m
.
matches
(
stmt
,
DOCSTRING_NODE
)
and
self
.
has_docstring
:
continue
continue
deduplicated_new_body
.
append
(
stmt
)
deduplicated_new_body
.
append
(
stmt
)
...
@@ -542,7 +547,7 @@ if __name__ == "__main__":
...
@@ -542,7 +547,7 @@ if __name__ == "__main__":
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
parser
.
add_argument
(
"--files_to_parse"
,
"--files_to_parse"
,
default
=
[
"
/Users/arthurzucker/Work/transformers/examples/diff-conversion/diff_my_new_model.py
"
],
default
=
[
"
all
"
],
nargs
=
"+"
,
nargs
=
"+"
,
help
=
"A list of `diff_xxxx` files that should be converted to single model file"
,
help
=
"A list of `diff_xxxx` files that should be converted to single model file"
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment