Unverified Commit 1d4b7978 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: Fix GIT batched captioning (#21738)

parent 78a93d17
......@@ -1217,7 +1217,7 @@ class TFGenerationMixin:
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
inputs, bos_token_id, model_kwargs=model_kwargs
)
else:
if inputs is not None:
......@@ -1225,9 +1225,7 @@ class TFGenerationMixin:
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
)
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
......@@ -1235,13 +1233,13 @@ class TFGenerationMixin:
self,
inputs: Optional[tf.Tensor] = None,
bos_token_id: Optional[int] = None,
encoder_outputs: Optional[ModelOutput] = None,
batch_size: Optional[int] = None,
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
) -> tf.Tensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
encoder_outputs = model_kwargs.get("encoder_outputs")
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.shape[:-1]
......@@ -1250,7 +1248,13 @@ class TFGenerationMixin:
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
batch_size = batch_size if batch_size is not None else 1
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size = 1
for value in model_kwargs.values():
if isinstance(value, tf.Tensor):
batch_size = value.shape[0]
break
return tf.ones((batch_size, 1), dtype=tf.int32) * bos_token_id
@staticmethod
......
......@@ -544,7 +544,7 @@ class GenerationMixin:
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, batch_size=model_kwargs["inputs_embeds"].shape[0]
inputs, bos_token_id, model_kwargs=model_kwargs
)
else:
if inputs is not None:
......@@ -552,9 +552,7 @@ class GenerationMixin:
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs.get("encoder_outputs")
)
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
......@@ -567,13 +565,13 @@ class GenerationMixin:
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
encoder_outputs: Optional[ModelOutput] = None,
batch_size: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.LongTensor:
"""Initializes input ids for generation, if necessary."""
if inputs is not None:
return inputs
encoder_outputs = model_kwargs.get("encoder_outputs")
if self.config.is_encoder_decoder and encoder_outputs is not None:
# make dummy input_ids with value -100, as a sanity check ensuring that they won't be used for encoding
shape = encoder_outputs.last_hidden_state.size()[:-1]
......@@ -582,7 +580,13 @@ class GenerationMixin:
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
batch_size = batch_size if batch_size is not None else 1
# If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
# soft-prompting or in multimodal implementations built on top of decoder-only language models.
batch_size = 1
for value in model_kwargs.values():
if isinstance(value, torch.Tensor):
batch_size = value.shape[0]
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def _prepare_attention_mask_for_generation(
......
......@@ -340,6 +340,24 @@ class GitModelTester:
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
def _test_batched_generate_captioning(self, config, input_ids, input_mask, pixel_values):
model = GitForCausalLM(config=config)
model.to(torch_device)
model.eval()
# generate
generated_ids = model.generate(
input_ids=None, # captioning -> no input_ids
attention_mask=None,
pixel_values=pixel_values,
do_sample=False,
max_length=20,
num_beams=2,
num_return_sequences=2,
)
self.parent.assertEqual(generated_ids.shape, (self.batch_size * 2, 20))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
......@@ -398,6 +416,10 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester._test_beam_search_generate(*config_and_inputs)
def test_batched_generate_captioning(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester._test_batched_generate_captioning(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment