chenpangpang / transformers / Commits / 5e11d72d

Unverified commit 5e11d72d, authored Sep 28, 2023 by Marc Sun, committed by GitHub on Sep 28, 2023.

fix_mbart_tied_weights (#26422)

* fix_mbart_tied_weights
* add test

Parent: 216dff75
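For context, the effect of this commit can be checked end to end with a minimal sketch (not part of the commit; "facebook/mbart-large-cc25" is just an illustrative public MBart checkpoint, and any checkpoint with tie_word_embeddings=True would do): after the fix, all four embedding matrices alias a single tensor.

# Minimal sketch (illustrative, not from this commit): verify that a tied MBart
# model shares one storage across input, output, encoder, and decoder embeddings.
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
ptrs = {
    model.get_input_embeddings().weight.data_ptr(),
    model.get_output_embeddings().weight.data_ptr(),
    model.model.encoder.embed_tokens.weight.data_ptr(),
    model.model.decoder.embed_tokens.weight.data_ptr(),
}
assert len(ptrs) == 1, "embeddings are not tied to a single storage"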
Changes: 2 changed files, with 42 additions and 0 deletions (+42 -0)

src/transformers/models/mbart/modeling_mbart.py   +5  -0
tests/models/mbart/test_modeling_mbart.py         +37 -0
src/transformers/models/mbart/modeling_mbart.py  (view file @ 5e11d72d)

@@ -1184,6 +1184,11 @@ class MBartModel(MBartPreTrainedModel):
     def get_decoder(self):
         return self.decoder
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         checkpoint=_CHECKPOINT_FOR_DOC,
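The fix works because _tie_or_clone_weights rebinds a module's weight to the same nn.Parameter as the input embeddings (cloning only in the TorchScript case, an assumption based on the public transformers source rather than this diff). A self-contained sketch of that aliasing using plain torch.nn, with stand-in names that are not from transformers:

import torch.nn as nn

vocab_size, d_model = 128, 16

# Stand-ins for the shared/encoder/decoder embedding tables in MBartModel.
shared = nn.Embedding(vocab_size, d_model)
encoder_embed = nn.Embedding(vocab_size, d_model)
decoder_embed = nn.Embedding(vocab_size, d_model)

# Tying: rebind both weights to the single shared Parameter. This mirrors
# the non-TorchScript branch of _tie_or_clone_weights.
encoder_embed.weight = shared.weight
decoder_embed.weight = shared.weight

# All three now alias one storage, which is what the new test asserts via data_ptr().
assert (
    encoder_embed.weight.data_ptr()
    == decoder_embed.weight.data_ptr()
    == shared.weight.data_ptr()
)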
tests/models/mbart/test_modeling_mbart.py  (view file @ 5e11d72d)

@@ -327,6 +327,43 @@ class MBartModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    def test_ensure_weights_are_shared(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+
+        config.tie_word_embeddings = True
+        model = MBartForConditionalGeneration(config)
+
+        # MBart shares four weights.
+        # Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
+                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
+                }
+            ),
+            1,
+        )
+
+        config.tie_word_embeddings = False
+        model = MBartForConditionalGeneration(config)
+
+        # MBart shares four weights.
+        # Not an issue to not have these correctly tied for torch.load, but it is an issue for safetensors.
+        self.assertEqual(
+            len(
+                {
+                    model.get_output_embeddings().weight.data_ptr(),
+                    model.get_input_embeddings().weight.data_ptr(),
+                    model.base_model.decoder.embed_tokens.weight.data_ptr(),
+                    model.base_model.encoder.embed_tokens.weight.data_ptr(),
+                }
+            ),
+            2,
+        )
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
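The test's comment about safetensors refers to the fact that safetensors refuses to serialize state dicts containing tensors that share storage, which is exactly the situation weights that are logically tied but not properly aliased (or properly tied but saved twice) can produce. A rough illustration, assuming the safetensors package is installed; the exact exception type and message may differ across versions:

import torch
from safetensors.torch import save_file

w = torch.randn(4, 4)
try:
    # Two keys backed by the same storage: save_file refuses to serialize this,
    # which is why tied weights must be detected (same data_ptr) and deduplicated
    # before saving.
    save_file({"a": w, "b": w}, "shared.safetensors")
except Exception as err:  # exact exception type varies by safetensors version
    print(f"safetensors rejected shared tensors: {err}")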