Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5e11d72d
Unverified
Commit
5e11d72d
authored
Sep 28, 2023
by
Marc Sun
Committed by
GitHub
Sep 28, 2023
Browse files
fix_mbart_tied_weights (#26422)
* fix_mbart_tied_weights * add test
parent
216dff75
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
0 deletions
+42
-0
src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mbart/modeling_mbart.py
+5
-0
tests/models/mbart/test_modeling_mbart.py
tests/models/mbart/test_modeling_mbart.py
+37
-0
No files found.
src/transformers/models/mbart/modeling_mbart.py
View file @
5e11d72d
...
...
@@ -1184,6 +1184,11 @@ class MBartModel(MBartPreTrainedModel):
def get_decoder(self):
    """Return the decoder sub-module stored on this model as `self.decoder`."""
    return self.decoder
def _tie_weights(self):
    """Tie the encoder and decoder token embeddings to the input embeddings.

    Does nothing unless `self.config.tie_word_embeddings` is truthy. Otherwise
    `self._tie_or_clone_weights` is called once for `self.encoder.embed_tokens`
    and once for `self.decoder.embed_tokens`, each time with the module returned
    by `self.get_input_embeddings()` as the source.
    """
    if not self.config.tie_word_embeddings:
        return

    # Look the source embeddings up once so both calls receive the same module.
    input_embeddings = self.get_input_embeddings()
    self._tie_or_clone_weights(self.encoder.embed_tokens, input_embeddings)
    self._tie_or_clone_weights(self.decoder.embed_tokens, input_embeddings)
@
add_start_docstrings_to_model_forward
(
MBART_INPUTS_DOCSTRING
)
@
add_code_sample_docstrings
(
checkpoint
=
_CHECKPOINT_FOR_DOC
,
...
...
tests/models/mbart/test_modeling_mbart.py
View file @
5e11d72d
...
...
@@ -327,6 +327,43 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
)
model
.
generate
(
num_beams
=
4
,
do_sample
=
True
,
early_stopping
=
False
,
num_return_sequences
=
3
)
def test_ensure_weights_are_shared(self):
    """Check that weight sharing follows `config.tie_word_embeddings`.

    For each setting a fresh `MBartForConditionalGeneration` is built and the
    `data_ptr()` values of four weights are collected into a set: the output
    embeddings, the input embeddings, and the decoder and encoder
    `embed_tokens`. The size of that set is the number of distinct pointers.
    """
    config, _ = self.model_tester.prepare_config_and_inputs()

    # (tie_word_embeddings, expected number of distinct weight pointers)
    #   True  -> all four weights resolve to a single pointer.
    #   False -> exactly two distinct pointers are expected.
    # Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
    cases = ((True, 1), (False, 2))

    for tie_word_embeddings, expected_distinct in cases:
        with self.subTest(tie_word_embeddings=tie_word_embeddings):
            config.tie_word_embeddings = tie_word_embeddings
            model = MBartForConditionalGeneration(config)

            weight_ptrs = {
                model.get_output_embeddings().weight.data_ptr(),
                model.get_input_embeddings().weight.data_ptr(),
                model.base_model.decoder.embed_tokens.weight.data_ptr(),
                model.base_model.encoder.embed_tokens.weight.data_ptr(),
            }
            self.assertEqual(len(weight_ptrs), expected_distinct)
def
assert_tensors_close
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment