Unverified Commit 8eb38f63 authored by NielsRogge, committed by GitHub

[Pix2struct] Simplify generation (#22527)

* Add model to doc tests

* Remove generate and replace by prepare_inputs_for_generation

* More fixes

* Remove print statements

* Update integration tests

* Fix generate

* Remove model from auto mapping

* Use auto processor

* Fix integration tests

* Fix test

* Add inference code snippet

* Remove is_encoder_decoder

* Update docs

* Remove notebook link
parent 95e70575
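In short, the hand-written `generate()` override on `Pix2StructForConditionalGeneration` is replaced by a standard `prepare_inputs_for_generation()` hook, so decoding now runs through the stock `GenerationMixin.generate()` loop. A minimal sketch of the resulting usage (checkpoint and image URL taken from the updated docstring below):

```python
import requests
from PIL import Image

from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

# the stock generate() loop now drives decoding; no model-specific override needed
generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```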
@@ -28,9 +28,8 @@ We therefore advise you to use these models for the tasks they have been fine tu
This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
The original code can be found [here](https://github.com/google-research/pix2struct).
-## Resources:
+## Resources
- [Paper](https://arxiv.org/abs/2210.03347)
-- [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb)
- [All models](https://huggingface.co/models?search=pix2struct)
@@ -70,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2struct).
## Pix2StructForConditionalGeneration
[[autodoc]] Pix2StructForConditionalGeneration
-- forward
+- forward
\ No newline at end of file
@@ -681,7 +681,7 @@ class GenerationConfig(PushToHubMixin):
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
-for decoder_name in ("decoder", "generator"):
+for decoder_name in ("decoder", "generator", "text_config"):
if decoder_name in config_dict:
default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name]
......
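For context: composite models such as Pix2Struct keep their generation-relevant defaults (eos/pad/decoder-start token ids) on the nested `text_config`, so that key is added to the lookup above. A minimal sketch of the effect, assuming Hub access:

```python
from transformers import AutoConfig, GenerationConfig

# with "text_config" in the lookup tuple, generation attributes stored on the
# nested decoder config are copied into the default GenerationConfig
config = AutoConfig.from_pretrained("google/pix2struct-textcaps-base")
gen_config = GenerationConfig.from_model_config(config)
print(gen_config.pad_token_id, gen_config.eos_token_id)
```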
@@ -358,9 +358,10 @@ class Pix2StructConfig(PretrainedConfig):
initializer_range=0.02,
is_vqa=False,
tie_word_embeddings=False,
+is_encoder_decoder=True,
**kwargs,
):
-super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs)
if text_config is None:
text_config = {}
@@ -373,9 +374,9 @@ class Pix2StructConfig(PretrainedConfig):
self.text_config = Pix2StructTextConfig(**text_config)
self.vision_config = Pix2StructVisionConfig(**vision_config)
self.text_config.encoder_hidden_size = self.vision_config.hidden_size
+self.decoder_start_token_id = self.text_config.decoder_start_token_id
+self.pad_token_id = self.text_config.pad_token_id
+self.eos_token_id = self.text_config.eos_token_id
self.initializer_factor = initializer_factor
self.initializer_range = initializer_range
......
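The composite config now also mirrors the nested text config's special token ids at the top level and defaults to `is_encoder_decoder=True`, which is where the generation utilities look first. A quick sketch using the default values:

```python
from transformers import Pix2StructConfig

config = Pix2StructConfig()
# special token ids are mirrored from the nested decoder config
assert config.pad_token_id == config.text_config.pad_token_id
assert config.eos_token_id == config.text_config.eos_token_id
assert config.decoder_start_token_id == config.text_config.decoder_start_token_id
assert config.is_encoder_decoder
```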
@@ -14,7 +14,6 @@
# limitations under the License.
""" Pix2Struct modeling file"""
-import copy
import math
from typing import Dict, List, Optional, Tuple, Union
@@ -1580,25 +1579,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
cross_attentions=all_cross_attentions,
)
-def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
-input_shape = input_ids.shape
-# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-if attention_mask is None:
-attention_mask = input_ids.new_ones(input_shape)
-# cut decoder_input_ids if past_key_values is used
-if past_key_values is not None:
-input_ids = input_ids[:, -1:]
-return {
-"input_ids": input_ids,
-"attention_mask": attention_mask,
-"past_key_values": past_key_values,
-"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
-"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
-"is_decoder": True,
-}
@add_start_docstrings(
"A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
@@ -1618,13 +1598,9 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
def __init__(self, config: Pix2StructConfig):
super().__init__(config)
-encoder_config = copy.deepcopy(config.vision_config)
-self.encoder = Pix2StructVisionModel(encoder_config)
-decoder_config = copy.deepcopy(config.text_config)
-self.decoder_start_token_id = decoder_config.pad_token_id
-self.decoder_eos_token_ids = decoder_config.eos_token_id
-self.decoder = Pix2StructTextModel(decoder_config)
+self.encoder = Pix2StructVisionModel(config.vision_config)
+self.decoder = Pix2StructTextModel(config.text_config)
self.is_vqa = config.is_vqa
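Dropping the `copy.deepcopy` means the encoder and decoder are built directly on the sub-config objects held by the composite config rather than on detached copies. A small sketch of the consequence (random weights, no download needed):

```python
from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration

config = Pix2StructConfig()
model = Pix2StructForConditionalGeneration(config)

# the submodules share the composite config's sub-config objects,
# so there is a single source of truth for sizes and token ids
assert model.encoder.config is config.vision_config
assert model.decoder.config is config.text_config
```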
@@ -1682,6 +1658,8 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
Example:
+Inference:
```python
>>> from PIL import Image
>>> import requests
@@ -1690,15 +1668,40 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
>>> labels = "A stop sign is on the street corner."
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
->>> inputs = processor(images=image, text=labels, return_tensors="pt", add_special_tokens=True)
+>>> inputs = processor(images=image, return_tensors="pt")
+>>> # autoregressive generation
+>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
+>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+>>> print(generated_text)
+A stop sign is on a street corner.
+```
+Training:
+```python
+>>> from PIL import Image
+>>> import requests
+>>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
+>>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
+>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
+>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+>>> image = Image.open(requests.get(url, stream=True).raw)
+>>> text = "A stop sign is on the street corner."
+>>> inputs = processor(images=image, return_tensors="pt")
+>>> labels = processor(text=text, return_tensors="pt").input_ids
>>> # forward pass
->>> outputs = model(**inputs)
->>> last_hidden_states = outputs.loss
+>>> outputs = model(**inputs, labels=labels)
+>>> loss = outputs.loss
+>>> print(loss.item())
+5.239729881286621
```"""
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1759,54 +1762,29 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
-@torch.no_grad()
-def generate(
+def prepare_inputs_for_generation(
self,
-flattened_patches: torch.FloatTensor,
-decoder_input_ids: Optional[torch.LongTensor] = None,
+input_ids,
+flattened_patches: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
-decoder_attention_mask: Optional[torch.LongTensor] = None,
-**generate_kwargs,
+decoder_attention_mask: Optional[torch.BoolTensor] = None,
+past_key_values=None,
+head_mask=None,
+decoder_head_mask=None,
+cross_attn_head_mask=None,
+use_cache=None,
+encoder_outputs=None,
+**kwargs,
):
r"""
Returns:
Example:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
>>> conditional_text = "A stop sign"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=conditional_text, return_tensors="pt", add_special_tokens=True)
>>> # forward pass
>>> outputs = model.generate(**inputs)
>>> print(processor.batch_decode(outputs, skip_special_tokens=True))
['A stop sign the street with a sign that says yes']
```"""
batch_size, _, _ = flattened_patches.shape
vision_outputs = self.encoder(flattened_patches=flattened_patches, attention_mask=attention_mask)
image_embeds = vision_outputs[0]
if isinstance(decoder_input_ids, torch.Tensor):
# check if the first element of `input_ids` is equal to `decoder_input_ids`:
if (decoder_input_ids[:, 0] != self.decoder_start_token_id).all().item():
# add `decoder_input_ids` as first token to `input_ids`
decoder_input_ids = torch.cat(
+if isinstance(input_ids, torch.Tensor):
+# check if the first element of `input_ids` is the `decoder_start_token_id`
+if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
+# prepend `decoder_start_token_id` to `input_ids`
+input_ids = torch.cat(
[
-torch.ones((decoder_input_ids.shape[0], 1), dtype=torch.long, device=decoder_input_ids.device)
-* self.decoder_start_token_id,
-decoder_input_ids,
+torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
+* self.config.decoder_start_token_id,
+input_ids,
],
dim=-1,
)
@@ -1823,20 +1801,26 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
],
dim=-1,
)
-elif decoder_input_ids is None:
-decoder_input_ids = (
-torch.LongTensor([[self.decoder_start_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-)
+elif input_ids is None:
+batch_size = flattened_patches.shape[0]
+input_ids = torch.LongTensor([[self.config.decoder_start_token_id]]).repeat(batch_size, 1).to(flattened_patches.device)
if decoder_attention_mask is None:
-decoder_attention_mask = torch.ones_like(decoder_input_ids).to(image_embeds.device)
+decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
-outputs = self.decoder.generate(
-input_ids=decoder_input_ids,
-attention_mask=decoder_attention_mask,
-encoder_hidden_states=image_embeds,
-encoder_attention_mask=attention_mask,
-**generate_kwargs,
-)
+# cut decoder_input_ids if past is used
+if past_key_values is not None:
+input_ids = input_ids[:, -1:]
-return outputs
+return {
+"flattened_patches": flattened_patches,
+"decoder_input_ids": input_ids,
+"past_key_values": past_key_values,
+"encoder_outputs": encoder_outputs,
+"attention_mask": attention_mask,
+"decoder_attention_mask": decoder_attention_mask,
+"head_mask": head_mask,
+"decoder_head_mask": decoder_head_mask,
+"cross_attn_head_mask": cross_attn_head_mask,
+"use_cache": use_cache,
+}
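For illustration, this is roughly the contract `generate()` relies on: at each step it calls `prepare_inputs_for_generation()` and feeds the returned dict straight into `forward()`. A simplified greedy sketch (the real loop also handles caching, beam search, and stopping criteria):

```python
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, Pix2StructForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

# start from the decoder start token and decode greedily for a few steps
input_ids = torch.tensor([[model.config.decoder_start_token_id]])
with torch.no_grad():
    for _ in range(10):
        step_inputs = model.prepare_inputs_for_generation(input_ids, **inputs)
        logits = model(**step_inputs).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(processor.batch_decode(input_ids, skip_special_tokens=True)[0])
```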
@@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
-if model.config.is_encoder_decoder:
-expected_arg_names = [
-"input_ids",
-"attention_mask",
-"decoder_input_ids",
-"decoder_attention_mask",
-]
-expected_arg_names.extend(
-["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
-if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
-else ["encoder_outputs"]
-)
-self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
-else:
-expected_arg_names = (
-["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"]
-)
-self.assertListEqual(arg_names[:1], expected_arg_names)
+expected_arg_names = [
+"flattened_patches",
+"attention_mask",
+"decoder_input_ids",
+"decoder_attention_mask",
+"head_mask",
+"decoder_head_mask",
+"cross_attn_head_mask",
+"encoder_outputs",
+"past_key_values",
+"labels",
+"decoder_inputs_embeds",
+"use_cache",
+]
+self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_training(self):
if not self.model_tester.is_training:
@@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
)
def test_vqa_model(self):
model_id = "ybelkada/pix2struct-ai2d-base"
model_id = "google/pix2struct-ai2d-base"
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
@@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
def test_vqa_model_batched(self):
model_id = "ybelkada/pix2struct-ai2d-base"
model_id = "google/pix2struct-ai2d-base"
image_urls = [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
......
@@ -306,6 +306,7 @@ src/transformers/models/pegasus/tokenization_pegasus.py
src/transformers/models/pegasus/tokenization_pegasus_fast.py
src/transformers/models/perceiver/tokenization_perceiver.py
src/transformers/models/phobert/tokenization_phobert.py
+src/transformers/models/pix2struct/modeling_pix2struct.py
src/transformers/models/plbart/tokenization_plbart.py
src/transformers/models/prophetnet/tokenization_prophetnet.py
src/transformers/models/rag/tokenization_rag.py
......