Unverified Commit 9edf3758 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[GIT] Fix training (#21133)



* Fix training

* Add test

* Fix failing tests
Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 0fb27dc9
...@@ -1487,20 +1487,21 @@ class GitForCausalLM(GitPreTrainedModel): ...@@ -1487,20 +1487,21 @@ class GitForCausalLM(GitPreTrainedModel):
sequence_output = outputs[0] sequence_output = outputs[0]
logits = self.output(sequence_output) logits = self.output(sequence_output)
lm_loss = None loss = None
if labels is not None: if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one # we are doing next-token prediction; shift prediction scores and input ids by one
shifted_logits = logits[:, :-1, :].contiguous() num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
labels = labels[:, 1:].contiguous() labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1)) loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast( return CausalLMOutputWithPast(
loss=lm_loss, loss=loss,
logits=logits, logits=logits,
past_key_values=outputs.past_key_values, past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
......
...@@ -29,7 +29,14 @@ if is_torch_available(): ...@@ -29,7 +29,14 @@ if is_torch_available():
import torch import torch
from torch import nn from torch import nn
from transformers import MODEL_FOR_PRETRAINING_MAPPING, GitForCausalLM, GitModel, GitVisionModel from transformers import (
MODEL_FOR_BACKBONE_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_MAPPING,
GitForCausalLM,
GitModel,
GitVisionModel,
)
from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST
...@@ -317,10 +324,12 @@ class GitModelTester: ...@@ -317,10 +324,12 @@ class GitModelTester:
result = model(input_ids) result = model(input_ids)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.text_seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.text_seq_length, self.vocab_size))
# TODO training # training
# result = model(input_ids, attention_mask=input_mask, pixel_values=pixel_values) result = model(input_ids, attention_mask=input_mask, pixel_values=pixel_values, labels=input_ids)
# self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
# self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertTrue(result.loss.item() > 0)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
...@@ -350,17 +359,16 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -350,17 +359,16 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase):
fx_compatible = False fx_compatible = False
test_torchscript = False test_torchscript = False
# special case for ForPreTraining model # special case for GitForCausalLM model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING): if model_class in get_values(MODEL_FOR_CAUSAL_LM_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.text_seq_length),
) dtype=torch.long,
inputs_dict["next_sentence_label"] = torch.zeros( device=torch_device,
self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
return inputs_dict return inputs_dict
...@@ -385,6 +393,31 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -385,6 +393,31 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
def test_training(self):
if not self.model_tester.is_training:
return
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
if model_class in [
*get_values(MODEL_MAPPING),
*get_values(MODEL_FOR_BACKBONE_MAPPING),
]:
continue
print("Model class:", model_class)
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
for k, v in inputs.items():
print(k, v.shape)
loss = model(**inputs).loss
loss.backward()
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in GIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: for model_name in GIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment