Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
77ed9fa1
Unverified
Commit
77ed9fa1
authored
Sep 18, 2023
by
Lysandre Debut
Committed by
GitHub
Sep 18, 2023
Browse files
[FSMT] Fix non-shared weights (#26187)
* Fix non-shared weights * Add tests * Edit tied weights keys
parent
f0a6057f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
2 deletions
+23
-2
src/transformers/models/fsmt/modeling_fsmt.py
src/transformers/models/fsmt/modeling_fsmt.py
+6
-2
tests/models/fsmt/test_modeling_fsmt.py
tests/models/fsmt/test_modeling_fsmt.py
+17
-0
No files found.
src/transformers/models/fsmt/modeling_fsmt.py
View file @
77ed9fa1
...
...
@@ -1034,7 +1034,7 @@ def _get_shape(t):
FSMT_START_DOCSTRING
,
)
class
FSMTModel
(
PretrainedFSMTModel
):
_tied_weights_keys
=
[
"decoder.embed_tokens.weight"
]
_tied_weights_keys
=
[
"decoder.embed_tokens.weight"
,
"decoder.output_projection.weight"
]
def
__init__
(
self
,
config
:
FSMTConfig
):
super
().
__init__
(
config
)
...
...
@@ -1055,6 +1055,10 @@ class FSMTModel(PretrainedFSMTModel):
def
get_decoder
(
self
):
return
self
.
decoder
def
_tie_weights
(
self
):
self
.
_tie_or_clone_weights
(
self
.
decoder
.
embed_tokens
,
self
.
get_input_embeddings
())
self
.
_tie_or_clone_weights
(
self
.
decoder
.
output_projection
,
self
.
get_input_embeddings
())
@
add_start_docstrings_to_model_forward
(
FSMT_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
checkpoint
=
_CHECKPOINT_FOR_DOC
,
...
...
@@ -1171,7 +1175,7 @@ class FSMTModel(PretrainedFSMTModel):
)
class
FSMTForConditionalGeneration
(
PretrainedFSMTModel
):
base_model_prefix
=
"model"
_tied_weights_keys
=
[
"
model.
decoder.embed_tokens.weight"
]
_tied_weights_keys
=
[
"decoder.embed_tokens.weight"
,
"decoder.output_projection.weight"
]
def
__init__
(
self
,
config
:
FSMTConfig
):
super
().
__init__
(
config
)
...
...
tests/models/fsmt/test_modeling_fsmt.py
View file @
77ed9fa1
...
...
@@ -271,6 +271,23 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
input_names
=
[
"input_ids"
,
"attention_mask"
],
)
def
test_ensure_weights_are_shared
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs
()
model
=
FSMTForConditionalGeneration
(
config
)
# FSMT shares three weights.
# Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
self
.
assertEqual
(
len
(
{
model
.
get_output_embeddings
().
weight
.
data_ptr
(),
model
.
get_input_embeddings
().
weight
.
data_ptr
(),
model
.
base_model
.
decoder
.
output_projection
.
weight
.
data_ptr
(),
}
),
1
,
)
@
unittest
.
skip
(
"can't be implemented for FSMT due to dual vocab."
)
def
test_resize_tokens_embeddings
(
self
):
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment