Unverified Commit 9dfd6a4b authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: handle text conditioning with multimodal encoder-decoder models (#22748)

parent 90ce374d
......@@ -837,12 +837,12 @@ class TFGenerationMixin:
# 6. Prepare model inputs which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
......@@ -1095,16 +1095,41 @@ class TFGenerationMixin:
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, tf.Tensor],
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, tf.Tensor]] = None,
) -> tf.Tensor:
# prepare `input_ids` for decoder if model is encoder-decoder
) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
decoder_input_ids_start = tf.ones((batch_size, 1), dtype=tf.int32) * decoder_start_token_id
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif tf.reduce_all(decoder_input_ids[:, 0] != decoder_start_token_id):
decoder_input_ids = tf.concat([decoder_input_ids_start, decoder_input_ids], axis=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = tf.concat(
(tf.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
axis=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
# retrieve decoder_start_token_id for encoder-decoder models
......
......@@ -642,18 +642,44 @@ class GenerationMixin:
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
model_input_name: str,
model_kwargs: Dict[str, torch.Tensor],
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
device: torch.device = None,
) -> torch.LongTensor:
) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
"""Prepares `decoder_input_ids` for generation with encoder-decoder models"""
# 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
# we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
decoder_input_ids = model_kwargs.pop("decoder_input_ids")
elif "input_ids" in model_kwargs and model_input_name != "input_ids":
decoder_input_ids = model_kwargs.pop("input_ids")
else:
decoder_input_ids = None
# 2. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
if device is None:
device = self.device
return torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
# no user input -> use decoder_start_token_id as decoder_input_ids
if decoder_input_ids is None:
decoder_input_ids = decoder_input_ids_start
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
decoder_input_ids = torch.cat([decoder_input_ids_start, decoder_input_ids], dim=-1)
if "decoder_attention_mask" in model_kwargs:
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
decoder_attention_mask = torch.cat(
(torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
dim=-1,
)
model_kwargs["decoder_attention_mask"] = decoder_attention_mask
return decoder_input_ids, model_kwargs
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
......@@ -1289,17 +1315,14 @@ class GenerationMixin:
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
device=inputs_tensor.device,
)
# conditional generation for multi-modal models.
if "input_ids" in model_kwargs and model_input_name == "pixel_values":
input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
......
......@@ -1776,35 +1776,6 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
encoder_outputs=None,
**kwargs,
):
if isinstance(input_ids, torch.Tensor):
# check if the first element of `input_ids` is equal to `input_ids`:
if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
# add `input_ids` as first token to `input_ids`
input_ids = torch.cat(
[
torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
* self.config.decoder_start_token_id,
input_ids,
],
dim=-1,
)
if decoder_attention_mask is not None:
decoder_attention_mask = torch.cat(
[
torch.ones(
(decoder_attention_mask.shape[0], 1),
dtype=torch.long,
device=decoder_attention_mask.device,
),
decoder_attention_mask,
],
dim=-1,
)
elif input_ids is None:
batch_size = flattened_patches.shape[0]
input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device)
if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
......
......@@ -94,8 +94,8 @@ class GenerationIntegrationTestsMixin:
# Decoder only call
outputs = bart_model.generate(decoder_input_ids=input_ids, max_new_tokens=max_new_tokens)
# 29 + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 32])
# 1 BOS + 29 (input length) + 3 new tokens
self.assertEqual(list(outputs.shape), [1, 33])
# Encoder decoder call > 20
outputs = bart_model.generate(max_new_tokens=max_new_tokens + 20)
......@@ -658,3 +658,31 @@ class GenerationIntegrationTestsMixin:
[token == model.config.pad_token_id for token in generated_tokens[0][expectation:]]
)
self.assertTrue(unpadded_correct_condition or padded_correct_condition)
def test_generate_vision2text_conditioning(self):
model_cls = self.framework_dependent_parameters["AutoModelForVision2Seq"]
floats_tensor = self.framework_dependent_parameters["floats_tensor"]
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
is_pt = not model_cls.__name__.startswith("TF")
pixel_values = floats_tensor((2, 3, 30, 30))
conditioning_input = create_tensor_fn([[10], [10]]) # this should be the 2nd output token, after the BOS token
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2")
if is_pt:
pixel_values = pixel_values.to(torch_device)
model = model.to(torch_device)
conditioning_input = conditioning_input.to(torch_device)
# we can condition on decoder_input_ids (expected decoder input) and input_ids (which we pipe internally as
# decoder_input_ids, if the encoder is not a model with text input)
output_sequences_decoder_input_ids = model.generate(
pixel_values, max_length=5, decoder_input_ids=conditioning_input
)
output_sequences_input_ids = model.generate(pixel_values, max_length=5, input_ids=conditioning_input)
if is_pt:
output_sequences_decoder_input_ids = output_sequences_decoder_input_ids.cpu().numpy()
output_sequences_input_ids = output_sequences_input_ids.cpu().numpy()
conditioning_input = conditioning_input.cpu().numpy()
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))
......@@ -1892,8 +1892,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1919,8 +1921,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
max_length = 20
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1949,8 +1953,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids = input_ids.expand(2, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -1982,8 +1988,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......@@ -2021,8 +2029,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
# Greedy
input_ids = input_ids.expand(6, -1)
model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
input_ids = bart_model._prepare_decoder_input_ids_for_generation(
input_ids.shape[0],
input_ids, model_kwargs = bart_model._prepare_decoder_input_ids_for_generation(
batch_size=input_ids.shape[0],
model_input_name=bart_model.main_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=bart_model.config.decoder_start_token_id,
bos_token_id=bart_model.config.bos_token_id,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment