Commit 26ba56cc (unverified)
Fix FSMT weight sharing (#26292)

Authored Sep 21, 2023 by Lysandre Debut; committed via GitHub on Sep 21, 2023
Parent: da971b22
Showing 2 changed files with 21 additions and 2 deletions:

  src/transformers/models/fsmt/modeling_fsmt.py  (+3, -2)
  tests/models/fsmt/test_modeling_fsmt.py        (+18, -0)
src/transformers/models/fsmt/modeling_fsmt.py

@@ -1056,8 +1056,9 @@ class FSMTModel(PretrainedFSMTModel):
         return self.decoder
 
     def _tie_weights(self):
-        self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
-        self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings())
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+            self._tie_or_clone_weights(self.decoder.output_projection, self.get_input_embeddings())
 
     @add_start_docstrings_to_model_forward(FSMT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
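In short, the fix makes FSMTModel._tie_weights respect config.tie_word_embeddings instead of tying the decoder's input embedding and output projection to the input embeddings unconditionally. A minimal sketch of what config-gated tying does with plain PyTorch modules (the TinyDecoder class and its field names are illustrative stand-ins, not the FSMT implementation):

import torch.nn as nn

class TinyDecoder(nn.Module):
    # Illustrative stand-in for a decoder with an input embedding and an
    # output projection over the same vocabulary.
    def __init__(self, vocab_size: int, d_model: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)
        if tie_word_embeddings:
            # Tie: point the projection at the embedding's Parameter,
            # so both modules share a single underlying tensor.
            self.output_projection.weight = self.embed_tokens.weight

tied = TinyDecoder(100, 16, tie_word_embeddings=True)
untied = TinyDecoder(100, 16, tie_word_embeddings=False)

# Shared storage shows up as an identical data pointer; untied weights keep
# distinct storage, which is what the new test below checks for FSMT.
assert tied.embed_tokens.weight.data_ptr() == tied.output_projection.weight.data_ptr()
assert untied.embed_tokens.weight.data_ptr() != untied.output_projection.weight.data_ptr()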
tests/models/fsmt/test_modeling_fsmt.py

@@ -273,6 +273,8 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_ensure_weights_are_shared(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.tie_word_embeddings = True
         model = FSMTForConditionalGeneration(config)
 
         # FSMT shares three weights.
@@ -288,6 +290,22 @@ class FSMTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
             1,
         )
 
+        config.tie_word_embeddings = False
+        model = FSMTForConditionalGeneration(config)
+
+        # FSMT shares three weights.
+        # Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.output_projection.weight.data_ptr(),
+                }
+            ),
+            2,
+        )
+
     @unittest.skip("can't be implemented for FSMT due to dual vocab.")
     def test_resize_tokens_embeddings(self):
         pass
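The comment about safetensors is the heart of the test: safetensors refuses to serialize state dicts whose entries alias the same memory, so weights that stay tied when config.tie_word_embeddings is False surface as save-time failures. A hedged sketch of that failure mode (toy tensors and illustrative key names; assumes safetensors' documented behavior of raising on shared tensors in save_file):

import torch
from safetensors.torch import save_file

weight = torch.randn(10, 4)
state = {
    "embed_tokens.weight": weight,
    "output_projection.weight": weight,  # aliases the same storage
}
try:
    save_file(state, "tied.safetensors")
except RuntimeError as err:
    # safetensors detects the shared storage and refuses to write the file
    print("rejected:", err)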