Unverified Commit 39371ee4 authored by Sam Shleifer, committed by GitHub

[Bart/Memory] don't create lm_head (#3323)

* delete lm_head, skips weight tying
* Fixed s3
parent 5ad2ea06
@@ -804,13 +804,8 @@ class BartForConditionalGeneration(PretrainedBartModel):
 
     def __init__(self, config: BartConfig):
         super().__init__(config)
-        # if base_model is None:
         base_model = BartModel(config)
         self.model = base_model
-        self.lm_head = _make_linear_from_emb(self.model.shared)
-
-    def tie_weights(self):
-        pass  # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.
 
     @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
     def forward(
@@ -875,7 +870,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             decoder_cached_states=decoder_cached_states,
             generation_mode=generation_mode,
         )
-        lm_logits = self.lm_head(outputs[0])
+        lm_logits = F.linear(outputs[0], self.model.shared.weight)
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss()
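The core of the change is visible in the hunk above: instead of calling a persistently registered lm_head module, the forward pass projects the decoder output against the shared embedding weight at call time. A minimal standalone sketch of why the two are equivalent, with illustrative shapes rather than Bart's real configuration:

import torch
import torch.nn.functional as F
from torch import nn

vocab_size, d_model = 128, 16               # illustrative, not Bart's real sizes
shared = nn.Embedding(vocab_size, d_model)
hidden_states = torch.randn(2, 7, d_model)  # (batch, seq_len, d_model)

# Before: a dedicated lm_head module, weight-tied to the shared embedding.
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = shared.weight              # tying: both point at one Parameter
logits_before = lm_head(hidden_states)

# After: no extra module is registered; the projection reuses the embedding
# weight directly, so nothing lm_head-shaped appears in the state_dict.
logits_after = F.linear(hidden_states, shared.weight)

assert torch.equal(logits_before, logits_after)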
@@ -932,7 +927,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def get_encoder(self):
         return self.model.encoder
 
     def get_output_embeddings(self):
-        return self.lm_head
+        return _make_linear_from_emb(self.model.shared)  # make it on the fly
 
 @add_start_docstrings(
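The body of `_make_linear_from_emb` is not shown in this diff. A plausible sketch of such a helper, assuming it wraps the embedding weight in a bias-free linear layer without copying it (the real helper in modeling_bart.py may differ in detail):

from torch import nn

def make_linear_from_emb(emb: nn.Embedding) -> nn.Linear:
    """Build a vocab projection that shares the embedding's Parameter."""
    vocab_size, emb_size = emb.weight.shape
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight = emb.weight  # share the Parameter, don't clone it
    return lin_layer

Because the linear layer is built per call rather than registered in __init__, it never enters the model's state_dict, which is why test_missing_keys is disabled for Bart in the test changes below.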
@@ -113,7 +113,8 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False
     test_torchscript = False
     test_head_masking = False
-    test_resize_embeddings = False  # This requires inputs_dict['input_ids']
+    test_resize_embeddings = True  # This requires inputs_dict['input_ids']
+    test_missing_keys = False  # because BartForConditionalGeneration and BartModel now have identical state_dict
 
     def setUp(self):
         self.model_tester = ModelTester(self)
@@ -371,6 +372,22 @@ class BartHeadTests(unittest.TestCase):
         )
         self.assertTrue(torch.eq(decoder_attn_mask_no_padding_no_causal_mask, 0).all())
 
+    def test_resize_tokens_embeddings_more(self):
+        config, input_ids, _ = self._get_config_and_data()
+
+        def _get_embs(m):
+            return (m.get_input_embeddings().weight.data.clone(), m.get_output_embeddings().weight.data.clone())
+
+        model = BartForConditionalGeneration(config).eval().to(torch_device)
+        input, output = _get_embs(model)
+        self.assertTrue(torch.eq(input, output).all())
+        new_vocab_size = 45
+        model.resize_token_embeddings(new_vocab_size)
+        input_new, output_new = _get_embs(model)
+        self.assertEqual(input_new.shape, (new_vocab_size, config.d_model))
+        self.assertEqual(output_new.shape, (new_vocab_size, config.d_model))
+        self.assertTrue(torch.eq(input_new, output_new).all())
+
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
@@ -58,6 +58,7 @@ class ModelTesterMixin:
     test_pruning = True
     test_resize_embeddings = True
     test_head_masking = True
+    test_missing_keys = True
     is_encoder_decoder = False
 
     def test_save_load(self):
@@ -527,6 +528,8 @@ class ModelTesterMixin:
 
         self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
 
     def test_correct_missing_keys(self):
+        if not self.test_missing_keys:
+            return
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes: