Unverified Commit 73f0a5d1 authored by Lysandre Debut, committed by GitHub

Fixes Loss for TransfoXL when using Trainer API v2 (#16140)



* fix(transfo_xl): Fixes TransfoXL support when using Trainer.

* fix(tests): Uses losses_1 and losses_2 pattern with TransfoXL test.

* fix(transfo_xl): Adds requested changes to allow for backward compatibility.

fix(transfo_xl): Adds requested changes to allow for backward compatibility.

fix(transfo_xl): Fixes code styling.

* Backward compatibility

* Update src/transformers/models/transfo_xl/modeling_transfo_xl.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Gustavo de Rosa <gth.rosa@uol.com.br>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 76c74b37
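For context, a minimal usage sketch (not code from this commit; the `transfo-xl-wt103` checkpoint name is assumed): setting `trainer_compatible=True` on the configuration opts into the new output layout, where the reduced scalar `loss` comes first in tuple outputs, and silences the deprecation warning; the dict-style output exposes both the scalar `loss` and the per-token `losses`.

```python
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, TransfoXLTokenizer

# Opt into the new output layout (scalar `loss` first) via the configuration.
config = TransfoXLConfig.from_pretrained("transfo-xl-wt103")
config.trainer_compatible = True
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103", config=config)

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
input_ids = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")["input_ids"]

outputs = model(input_ids, labels=input_ids)
print(outputs.loss.shape)    # torch.Size([]) -- reduced loss used by the Trainer
print(outputs.losses.shape)  # (batch_size, sequence_length - 1) per-token losses
```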
@@ -17,6 +17,7 @@
PyTorch Transformer XL model. Adapted from https://github.com/kimiyoung/transformer-xl. In particular
https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
"""
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple
@@ -692,6 +693,8 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `labels` is provided):
Reduced language modeling loss.
""" """
losses: Optional[torch.FloatTensor] = None losses: Optional[torch.FloatTensor] = None
...@@ -699,6 +702,7 @@ class TransfoXLLMHeadModelOutput(ModelOutput): ...@@ -699,6 +702,7 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
mems: List[torch.FloatTensor] = None mems: List[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
loss: Optional[torch.FloatTensor] = None
@property
def logits(self):
@@ -1011,6 +1015,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
super().__init__(config)
self.transformer = TransfoXLModel(config)
self.sample_softmax = config.sample_softmax
self.trainer_compatible = getattr(config, "trainer_compatible", False)
if not self.trainer_compatible:
warnings.warn(
"The output of TransfoXL will be updated in v5 to support a single loss as first argument. In order"
"to use that updated output, please specify `trainer_compatible=True` as your configuration attribute.",
DeprecationWarning,
)
assert (
self.sample_softmax <= 0
@@ -1095,17 +1107,38 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
if labels is not None:
# Prevents all labels from being -100, which would throw an error
# when backpropagating the loss
miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100
if miss_valid_label:
# Sets an <EOS> token, just to prevent loss from being NaN
labels[0, 1] = self.config.eos_token_id
softmax_output = self.crit(pred_hid, labels)
prediction_scores = softmax_output.view(bsz, tgt_len, -1) if labels is None else ()
loss = softmax_output.view(bsz, tgt_len - 1) if labels is not None else None
if labels is not None:
losses = softmax_output.view(bsz, tgt_len - 1)
# Avoids incorporating padding (-100) tokens into the loss value
loss = losses[losses != 0].mean()
else:
losses, loss = None, None
if not return_dict:
output = (prediction_scores,) + transformer_outputs[1:]
if self.trainer_compatible:
output = (prediction_scores, losses) if losses is not None else (prediction_scores,)
output += transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
else:
output = (prediction_scores, *transformer_outputs[1:])
output = ((losses,) + output) if losses is not None else output
return (output + (loss,)) if loss is not None else output
return TransfoXLLMHeadModelOutput(
losses=loss,
loss=loss,
prediction_scores=prediction_scores,
losses=losses,
mems=transformer_outputs.mems,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
...
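Not part of the diff: a small runnable sketch of the two pieces of loss handling added above. The guard keeps a batch whose labels are all -100 from producing an error on backward, and the zero-filtered mean keeps padded positions out of the reduced loss (the adaptive softmax criterion leaves positions labeled -100 at 0.0). The numeric values and the `0` stand-in for `eos_token_id` are hypothetical.

```python
import torch

# 1) Guard against a row whose labels after position 0 are all -100.
labels = torch.full((1, 8), -100)
miss_valid_label = labels[0, 1:].sum() == (labels.size(1) - 1) * -100
if miss_valid_label:
    labels[0, 1] = 0  # hypothetical eos_token_id; keeps the loss finite

# 2) Reduce per-token losses while ignoring padded (zero) entries,
#    mirroring `losses[losses != 0].mean()` in the patch.
losses = torch.tensor([[2.3, 1.7, 0.0, 0.0],
                       [1.1, 0.9, 1.3, 0.0]])
loss = losses[losses != 0].mean()
print(loss)  # tensor(1.4600)
```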
@@ -132,30 +132,90 @@ class TransfoXLModelTester:
outputs2 = model(input_ids_2, labels=lm_labels, mems=outputs1["mems"])
outputs = {
"loss_1": outputs1["losses"],
"loss_1": outputs1["loss"],
"losses_1": outputs1["losses"],
"mems_1": outputs1["mems"], "mems_1": outputs1["mems"],
"lm_logits_1": lm_logits_1, "lm_logits_1": lm_logits_1,
"loss_2": outputs2["losses"], "loss_2": outputs2["loss"],
"losses_2": outputs2["losses"],
"mems_2": outputs2["mems"], "mems_2": outputs2["mems"],
"lm_logits_2": lm_logits_2, "lm_logits_2": lm_logits_2,
} }
return outputs return outputs
def check_transfo_xl_lm_head_output(self, result): def check_transfo_xl_lm_head_output(self, result):
self.parent.assertEqual(result["loss_1"].shape, (self.batch_size, self.seq_length - 1)) self.parent.assertEqual(result["loss_1"].shape, ())
self.parent.assertEqual(result["losses_1"].shape, (self.batch_size, self.seq_length - 1))
self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result["lm_logits_1"].shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual( self.parent.assertListEqual(
[mem.shape for mem in result["mems_1"]], [mem.shape for mem in result["mems_1"]],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers, [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
self.parent.assertEqual(result["loss_2"].shape, (self.batch_size, self.seq_length - 1)) self.parent.assertEqual(result["loss_2"].shape, ())
self.parent.assertEqual(result["losses_2"].shape, (self.batch_size, self.seq_length - 1))
self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result["lm_logits_2"].shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertListEqual( self.parent.assertListEqual(
[mem.shape for mem in result["mems_2"]], [mem.shape for mem in result["mems_2"]],
[(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers, [(self.mem_len, self.batch_size, self.hidden_size)] * self.num_hidden_layers,
) )
def create_transfo_xl_lm_head_trainer_compatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels):
config.trainer_compatible = True
model = TransfoXLLMHeadModel(config)
model.to(torch_device)
model.eval()
lm_logits_1 = model(input_ids_1, return_dict=False)[0]
outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
loss_1, _, losses_1, mems_1 = outputs1[:4]
lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1, return_dict=False)
loss_2, _, losses_2, mems_2 = outputs2[:4]
outputs = {
"losses_1": losses_1,
"mems_1": mems_1,
"lm_logits_1": lm_logits_1,
"loss_1": loss_1,
"losses_2": losses_2,
"mems_2": mems_2,
"lm_logits_2": lm_logits_2,
"loss_2": loss_2,
}
config.trainer_compatible = None
return outputs
def create_transfo_xl_lm_head_trainer_incompatible_tuple(self, config, input_ids_1, input_ids_2, lm_labels):
config.trainer_compatible = False
model = TransfoXLLMHeadModel(config)
model.to(torch_device)
model.eval()
lm_logits_1 = model(input_ids_1, return_dict=False)[0]
outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
losses_1, _, mems_1 = outputs1[:3]
loss_1 = outputs1[-1]
lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
outputs2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
losses_2, _, mems_2 = outputs2[:3]
loss_2 = outputs2[-1]
outputs = {
"losses_1": losses_1,
"mems_1": mems_1,
"lm_logits_1": lm_logits_1,
"loss_1": loss_1,
"losses_2": losses_2,
"mems_2": mems_2,
"lm_logits_2": lm_logits_2,
"loss_2": loss_2,
}
config.trainer_compatible = None
return outputs
def create_and_check_transfo_xl_for_sequence_classification(self, config, input_ids_1, input_ids_2, lm_labels):
config.num_labels = self.num_labels
model = TransfoXLForSequenceClassification(config)
@@ -220,9 +280,16 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
def test_transfo_xl_lm_head(self):
self.model_tester.set_seed()
config_and_inputs = self.model_tester.prepare_config_and_inputs()
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result)
output_result = self.model_tester.create_transfo_xl_lm_head_trainer_compatible_tuple(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result)
output_result = self.model_tester.create_transfo_xl_lm_head_trainer_incompatible_tuple(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result)
def test_transfo_xl_sequence_classification_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_transfo_xl_for_sequence_classification(*config_and_inputs)
@@ -232,10 +299,8 @@ class TransfoXLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestC
return
@require_torch_multi_gpu
@unittest.skip(
reason="Transfo-XL does not work with data parallel (DP) because of a bug in PyTorch: https://github.com/pytorch/pytorch/issues/36035"
)
def test_multi_gpu_data_parallel_forward(self):
# Opt-out of this test.
pass
@slow
...
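Why the ordering matters for the Trainer API (a simplified sketch of how `Trainer.compute_loss` picks up the loss, not code from this diff): the Trainer reads either a `loss` key from a dict-like model output or the element at index 0 of a tuple output, which is exactly what `trainer_compatible=True` now puts first.

```python
# Simplified sketch of the Trainer-side expectation this change satisfies.
def compute_loss(model, inputs):
    outputs = model(**inputs)
    # Dict-like ModelOutput: the "loss" key must exist.
    # Plain tuple: the scalar loss must sit at index 0.
    loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
    return loss
```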