chenpangpang/transformers
Commit 39371ee4 (unverified)
Authored Mar 26, 2020 by Sam Shleifer; committed by GitHub on Mar 26, 2020
[Bart/Memory] don't create lm_head (#3323)

* delete lm_head, skips weight tying
* Fixed s3
Parent: 5ad2ea06
Showing 3 changed files with 23 additions and 8 deletions.
src/transformers/modeling_bart.py   +2 -7
tests/test_modeling_bart.py         +18 -1
tests/test_modeling_common.py       +3 -0
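The memory saving comes from no longer materialising a separate lm_head module: decoder hidden states are now projected onto the shared input-embedding matrix directly, so there is no second projection to store or keep tied. A minimal, self-contained sketch of why the two formulations give identical logits (sizes are illustrative, not BART's real dimensions):

import torch
import torch.nn.functional as F
from torch import nn

vocab_size, d_model = 50, 16
shared = nn.Embedding(vocab_size, d_model)   # plays the role of BartModel.shared
hidden = torch.randn(2, 7, d_model)          # stand-in for decoder output: (batch, seq_len, d_model)

# Old approach: a dedicated projection whose weight must be kept tied to the embedding.
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = shared.weight               # weight tying
logits_old = lm_head(hidden)

# New approach: project straight onto the embedding matrix, no extra module to keep in sync.
logits_new = F.linear(hidden, shared.weight)

assert torch.equal(logits_old, logits_new)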
src/transformers/modeling_bart.py

@@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def __init__(self, config: BartConfig):
         super().__init__(config)
         # if base_model is None:
         base_model = BartModel(config)
         self.model = base_model
-        self.lm_head = _make_linear_from_emb(self.model.shared)
-
-    def tie_weights(self):
-        pass  # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
-
     @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
     def forward(
@@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             decoder_cached_states=decoder_cached_states,
             generation_mode=generation_mode,
         )
-        lm_logits = self.lm_head(outputs[0])
+        lm_logits = F.linear(outputs[0], self.model.shared.weight)
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss()
@@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         return self.model.encoder

     def get_output_embeddings(self):
-        return self.lm_head
+        return _make_linear_from_emb(self.model.shared)  # make it on the fly

     @add_start_docstrings(
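get_output_embeddings still has to hand back an nn.Linear so that generic utilities such as weight tying and embedding resizing keep working, which is why it now calls _make_linear_from_emb on the fly instead of returning a stored submodule. The helper itself is not part of this diff, so the sketch below is only an approximation of what wrapping an embedding as a bias-free Linear amounts to:

from torch import nn

def make_linear_from_emb(emb: nn.Embedding) -> nn.Linear:
    """Expose an embedding matrix as a bias-free Linear without copying weights."""
    vocab_size, emb_dim = emb.weight.shape
    lin = nn.Linear(emb_dim, vocab_size, bias=False)
    lin.weight = emb.weight  # share the same Parameter; no new memory is allocated
    return lin

shared = nn.Embedding(50, 16)
head = make_linear_from_emb(shared)
assert head.weight.data_ptr() == shared.weight.data_ptr()  # same underlying storage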
tests/test_modeling_bart.py

@@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False
     test_head_masking = False
-    test_resize_embeddings = False  # This requires inputs_dict['input_ids']
+    test_resize_embeddings = True  # This requires inputs_dict['input_ids']
+    test_missing_keys = False  # because BartForConditionalGeneration and BartModel now have identical state_dict

     def setUp(self):
         self.model_tester = ModelTester(self)
@@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase):
         )
         self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())

+    def test_resize_tokens_embeddings_more(self):
+        config, input_ids, _ = self._get_config_and_data()
+
+        def _get_embs(m):
+            return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
+
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        input, output = _get_embs(model)
+        self.assertTrue(torch.eq(input, output).all())
+        new_vocab_size = 45
+        model.resize_token_embeddings(new_vocab_size)
+        input_new, output_new = _get_embs(model)
+        self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
+        self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
+        self.assertTrue(torch.eq(input_new, output_new).all())
+

 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
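The new test pins down the user-visible behaviour: because the output projection is always derived from model.shared, resizing the token embeddings resizes the "head" as well, and the two stay identical. A short usage sketch along the same lines (the tiny config values are made up for illustration):

import torch
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=50, d_model=16,
    encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config)

model.resize_token_embeddings(60)  # grow the vocabulary from 50 to 60
in_emb = model.get_input_embeddings().weight
out_emb = model.get_output_embeddings().weight
assert in_emb.shape == (60, config.d_model)
assert torch.equal(in_emb, out_emb)  # input and output embeddings remain the same matrix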
tests/test_modeling_common.py

@@ -58,6 +58,7 @@ class ModelTesterMixin:
     test_pruning = True
     test_resize_embeddings = True
     test_head_masking = True
+    test_missing_keys = True
     is_encoder_decoder = False

     def test_save_load(self):
@@ -527,6 +528,8 @@ class ModelTesterMixin:
             self.assertTrue(x is None or isinstance(x, torch.nn.Linear))

     def test_correct_missing_keys(self):
+        if not self.test_missing_keys:
+            return
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes: