"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b0d539ccad090c8949c1740a9758b4152fad5f72"
Unverified Commit e1ad1886 authored by Peter Lin, committed by GitHub

Fix git model for generate with beam search. (#21071)



* Fix git model for generate with beam search.

* Update comment

* Fix bug on multi batch

* Add generate tests

* Clean up tests

* Fix style
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent e15f0d73
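
Background: with beam search, generate turns each prompt into num_beams candidate rows, so the text batch grows to batch_size * num_beams. A rough sketch of that expansion step (simplified; the real logic lives in the generation utilities of transformers):

import torch

input_ids = torch.tensor([[101], [101]])  # batch_size = 2 prompts
num_beams = 3

# Each prompt becomes num_beams candidate rows, so the effective
# text batch is batch_size * num_beams = 6.
expanded_ids = input_ids.repeat_interleave(num_beams, dim=0)
print(expanded_ids.shape)  # torch.Size([6, 1])

The hunk below tiles the projected visual features so their batch dimension matches the (possibly expanded) text embeddings before the two are concatenated.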
src/transformers/models/git/modeling_git.py
@@ -1264,6 +1264,11 @@ class GitModel(GitPreTrainedModel):
                 device=embedding_output.device,
             )
 
+        # Repeat visual features to match embedding batch size.
+        projected_visual_features = projected_visual_features.repeat(
+            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
+        )
+
         # concatenate patch token and text token embeddings
         hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
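
A minimal shape sketch of the added repeat, using toy dimensions in place of GIT's real patch count and hidden size:

import torch

batch_size, num_beams = 2, 3
projected_visual_features = torch.randn(batch_size, 4, 8)     # (batch, num_patches, hidden)
embedding_output = torch.randn(batch_size * num_beams, 5, 8)  # text rows after beam expansion

# Tile the visual features along dim 0 so both tensors agree on the
# batch size; without expansion the ratio is 1 and this is a no-op.
projected_visual_features = projected_visual_features.repeat(
    embedding_output.size(0) // projected_visual_features.size(0), 1, 1
)
hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
print(hidden_states.shape)  # torch.Size([6, 9, 8])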
tests/models/git/test_modeling_git.py
@@ -21,6 +21,7 @@ from transformers import GitConfig, GitProcessor, GitVisionConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
@@ -29,14 +30,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import (
-        MODEL_FOR_BACKBONE_MAPPING,
-        MODEL_FOR_CAUSAL_LM_MAPPING,
-        MODEL_MAPPING,
-        GitForCausalLM,
-        GitModel,
-        GitVisionModel,
-    )
+    from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, GitForCausalLM, GitModel, GitVisionModel
     from transformers.models.git.modeling_git import GIT_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -259,13 +253,9 @@ class GitModelTester:
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
 
-        token_labels = None
-        if self.use_labels:
-            token_labels = ids_tensor([self.batch_size, self.text_seq_length], self.num_labels)
-
         config = self.get_config()
 
-        return config, input_ids, input_mask, pixel_values, token_labels
+        return config, input_ids, input_mask, pixel_values
 
     def get_config(self):
         """
@@ -292,7 +282,7 @@ class GitModelTester:
             pad_token_id=self.pad_token_id,
         )
 
-    def create_and_check_model(self, config, input_ids, input_mask, pixel_values, token_labels):
+    def create_and_check_model(self, config, input_ids, input_mask, pixel_values):
         model = GitModel(config=config)
         model.to(torch_device)
         model.eval()
@@ -310,7 +300,7 @@ class GitModelTester:
             result.last_hidden_state.shape, (self.batch_size, self.text_seq_length, self.hidden_size)
         )
 
-    def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values, token_labels):
+    def create_and_check_for_causal_lm(self, config, input_ids, input_mask, pixel_values):
         model = GitForCausalLM(config=config)
         model.to(torch_device)
         model.eval()
@@ -331,6 +321,24 @@ class GitModelTester:
         self.parent.assertEqual(result.loss.shape, ())
         self.parent.assertTrue(result.loss.item() > 0)
 
+    def _test_beam_search_generate(self, config, input_ids, input_mask, pixel_values):
+        model = GitForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        # generate
+        generated_ids = model.generate(
+            input_ids,
+            attention_mask=input_mask,
+            pixel_values=pixel_values,
+            do_sample=False,
+            max_length=20,
+            num_beams=2,
+            num_return_sequences=2,
+        )
+
+        self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
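
For reference, this is the kind of user-facing call the fix enables: image captioning with beam search. A sketch assuming the microsoft/git-base checkpoint and network access:

import requests
from PIL import Image
from transformers import AutoProcessor, GitForCausalLM

processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = GitForCausalLM.from_pretrained("microsoft/git-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt").pixel_values

# num_beams > 1 previously raised a shape mismatch when the visual
# features were concatenated with the beam-expanded text embeddings.
generated_ids = model.generate(pixel_values=pixel_values, num_beams=2, max_length=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))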
@@ -339,7 +347,6 @@ class GitModelTester:
             input_ids,
             input_mask,
             pixel_values,
-            token_labels,
         ) = config_and_inputs
         inputs_dict = {
@@ -352,7 +359,7 @@
 
 @require_torch
-class GitModelTest(ModelTesterMixin, unittest.TestCase):
+class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (GitModel, GitForCausalLM) if is_torch_available() else ()
     all_generative_model_classes = (GitForCausalLM,) if is_torch_available() else ()
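
Mixing GenerationTesterMixin into the bases pulls the shared generation tests into this class; they run over all_generative_model_classes. A self-contained sketch of the pattern (hypothetical names, not the real mixin from tests/generation/test_utils.py):

import unittest

class _GenerationMixinSketch:
    # Hypothetical stand-in for GenerationTesterMixin: each shared test
    # walks the all_generative_model_classes tuple that the concrete
    # test class declares.
    all_generative_model_classes = ()

    def test_generate_is_exposed(self):
        for model_class in self.all_generative_model_classes:
            self.assertTrue(hasattr(model_class, "generate"))

class _FakeGenerativeModel:
    def generate(self):
        return []

class _ModelTestSketch(_GenerationMixinSketch, unittest.TestCase):
    all_generative_model_classes = (_FakeGenerativeModel,)

if __name__ == "__main__":
    unittest.main()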
@@ -383,40 +390,19 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    def test_model_various_embeddings(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs()
-        for type in ["absolute", "relative_key", "relative_key_query"]:
-            config_and_inputs[0].position_embedding_type = type
-            self.model_tester.create_and_check_model(*config_and_inputs)
-
     def test_for_causal_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
 
-    def test_training(self):
-        if not self.model_tester.is_training:
-            return
-
-        for model_class in self.all_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-            config.return_dict = True
-
-            if model_class in [
-                *get_values(MODEL_MAPPING),
-                *get_values(MODEL_FOR_BACKBONE_MAPPING),
-            ]:
-                continue
-
-            print("Model class:", model_class)
-
-            model = model_class(config)
-            model.to(torch_device)
-            model.train()
-            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            for k, v in inputs.items():
-                print(k, v.shape)
-            loss = model(**inputs).loss
-            loss.backward()
+    def test_beam_search_generate(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester._test_beam_search_generate(*config_and_inputs)
+
+    def test_model_various_embeddings(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        for type in ["absolute", "relative_key", "relative_key_query"]:
+            config_and_inputs[0].position_embedding_type = type
+            self.model_tester.create_and_check_model(*config_and_inputs)
 
     @slow
     def test_model_from_pretrained(self):
@@ -424,6 +410,22 @@ class GitModelTest(ModelTesterMixin, unittest.TestCase):
         model = GitModel.from_pretrained(model_name)
         self.assertIsNotNone(model)
 
+    @unittest.skip(reason="GIT has pixel values as additional input")
+    def test_beam_search_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="GIT has pixel values as additional input")
+    def test_contrastive_generate(self):
+        pass
+
+    @unittest.skip(reason="GIT has pixel values as additional input")
+    def test_contrastive_generate_dict_outputs_use_cache(self):
+        pass
+
+    @unittest.skip(reason="GIT has pixel values as additional input")
+    def test_greedy_generate_dict_outputs_use_cache(self):
+        pass
+
 
 @require_torch
 @require_vision
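
The skips above override tests inherited from GenerationTesterMixin that do not account for GIT's extra pixel_values input. A self-contained illustration of the override-and-skip pattern (hypothetical classes, not the real test suite):

import unittest

class SharedTests(unittest.TestCase):
    def test_shared_behavior(self):
        self.assertTrue(True)

class SpecializedTests(SharedTests):
    # Overriding an inherited test with @unittest.skip opts this
    # subclass out of that one test while keeping the rest.
    @unittest.skip(reason="not applicable to this subclass")
    def test_shared_behavior(self):
        pass

if __name__ == "__main__":
    unittest.main()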