Unverified Commit 8eb38f63 authored by NielsRogge, committed by GitHub

[Pix2struct] Simplify generation (#22527)

* Add model to doc tests

* Remove generate and replace it with prepare_inputs_for_generation

* More fixes

* Remove print statements

* Update integration tests

* Fix generate

* Remove model from auto mapping

* Use auto processor

* Fix integration tests

* Fix test

* Add inference code snippet

* Remove is_encoder_decoder

* Update docs

* Remove notebook link
parent 95e70575
--- a/docs/source/en/model_doc/pix2struct.mdx
+++ b/docs/source/en/model_doc/pix2struct.mdx
@@ -28,9 +28,8 @@ We therefore advise you to use these models for the tasks they have been fine tuned on.
 This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
 The original code can be found [here](https://github.com/google-research/pix2struct).
 
-## Resources:
+## Resources
 
-- [Paper](https://arxiv.org/abs/2210.03347)
 - [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb)
 - [All models](https://huggingface.co/models?search=pix2struct)
@@ -70,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2struct).
 ## Pix2StructForConditionalGeneration
 
 [[autodoc]] Pix2StructForConditionalGeneration
-    - forward
\ No newline at end of file
+    - forward
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -681,7 +681,7 @@ class GenerationConfig(PushToHubMixin):
 
         # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
         # generation config.
-        for decoder_name in ("decoder", "generator"):
+        for decoder_name in ("decoder", "generator", "text_config"):
             if decoder_name in config_dict:
                 default_generation_config = GenerationConfig()
                 decoder_config = config_dict[decoder_name]
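The hunk above is what makes the simplification possible: when a `GenerationConfig` is derived from a composite model config, generation attributes nested under `text_config` (such as Pix2Struct's token ids) are now picked up as well. A minimal sketch of the effect, not part of the diff, assuming Pix2Struct's default token ids:

```python
# Minimal sketch: GenerationConfig.from_model_config() now also inspects the
# nested `text_config` when collecting generation defaults, so Pix2Struct's
# token ids propagate automatically.
from transformers import GenerationConfig, Pix2StructConfig

config = Pix2StructConfig()
gen_config = GenerationConfig.from_model_config(config)

# Both should mirror the nested text config rather than staying unset:
print(gen_config.eos_token_id, config.text_config.eos_token_id)
print(gen_config.pad_token_id, config.text_config.pad_token_id)
```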
--- a/src/transformers/models/pix2struct/configuration_pix2struct.py
+++ b/src/transformers/models/pix2struct/configuration_pix2struct.py
@@ -358,9 +358,10 @@ class Pix2StructConfig(PretrainedConfig):
         initializer_range=0.02,
         is_vqa=False,
         tie_word_embeddings=False,
+        is_encoder_decoder=True,
         **kwargs,
     ):
-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs)
 
         if text_config is None:
             text_config = {}
@@ -373,9 +374,9 @@ class Pix2StructConfig(PretrainedConfig):
         self.text_config = Pix2StructTextConfig(**text_config)
         self.vision_config = Pix2StructVisionConfig(**vision_config)
 
-        self.text_config.encoder_hidden_size = self.vision_config.hidden_size
         self.decoder_start_token_id = self.text_config.decoder_start_token_id
         self.pad_token_id = self.text_config.pad_token_id
+        self.eos_token_id = self.text_config.eos_token_id
 
         self.initializer_factor = initializer_factor
         self.initializer_range = initializer_range
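In practice the config change does two things: the composite config now advertises itself as encoder-decoder, which is what lets the generic `generate()` take the encoder-decoder path, and it mirrors `eos_token_id` from the text config. A hedged sketch, assuming the default Pix2Struct token ids:

```python
# Hedged sketch (not part of the diff) of the resulting config behaviour.
from transformers import Pix2StructConfig

config = Pix2StructConfig()
print(config.is_encoder_decoder)  # True: generic generate() routes through the encoder-decoder path
print(config.eos_token_id == config.text_config.eos_token_id)  # True: mirrored from the text config
print(config.decoder_start_token_id, config.pad_token_id)  # likewise taken from the text config
```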
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Pix2Struct modeling file"""
 
-import copy
 import math
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -1580,25 +1579,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
             cross_attentions=all_cross_attentions,
         )
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
-        input_shape = input_ids.shape
-
-        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
-        if attention_mask is None:
-            attention_mask = input_ids.new_ones(input_shape)
-
-        # cut decoder_input_ids if past_key_values is used
-        if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
-        return {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
-            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
-            "is_decoder": True,
-        }
-
 
 @add_start_docstrings(
     "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
@@ -1618,13 +1598,9 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
     def __init__(self, config: Pix2StructConfig):
         super().__init__(config)
 
-        encoder_config = copy.deepcopy(config.vision_config)
-        self.encoder = Pix2StructVisionModel(encoder_config)
-
-        decoder_config = copy.deepcopy(config.text_config)
-        self.decoder_start_token_id = decoder_config.pad_token_id
-        self.decoder_eos_token_ids = decoder_config.eos_token_id
-        self.decoder = Pix2StructTextModel(decoder_config)
+        self.encoder = Pix2StructVisionModel(config.vision_config)
+        self.decoder = Pix2StructTextModel(config.text_config)
 
         self.is_vqa = config.is_vqa
@@ -1682,6 +1658,8 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
 
         Example:
 
+        Inference:
+
         ```python
         >>> from PIL import Image
         >>> import requests
@@ -1690,15 +1668,40 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
         >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
         >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
 
-        >>> labels = "A stop sign is on the street corner."
         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
 
-        >>> inputs = processor(images=image, text=labels, return_tensors="pt", add_special_tokens=True)
+        >>> inputs = processor(images=image, return_tensors="pt")
+
+        >>> # autoregressive generation
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
+        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        >>> print(generated_text)
+        A stop sign is on a street corner.
+        ```
+
+        Training:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
+
+        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
+        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
+
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> text = "A stop sign is on the street corner."
+
+        >>> inputs = processor(images=image, return_tensors="pt")
+        >>> labels = processor(text=text, return_tensors="pt").input_ids
 
         >>> # forward pass
-        >>> outputs = model(**inputs)
-        >>> last_hidden_states = outputs.loss
+        >>> outputs = model(**inputs, labels=labels)
+        >>> loss = outputs.loss
+        >>> print(loss.item())
+        5.239729881286621
         ```"""
 
         use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1759,54 +1762,29 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
             encoder_attentions=encoder_outputs.attentions,
         )
 
-    @torch.no_grad()
-    def generate(
+    def prepare_inputs_for_generation(
         self,
-        flattened_patches: torch.FloatTensor,
-        decoder_input_ids: Optional[torch.LongTensor] = None,
+        input_ids,
+        flattened_patches: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
-        **generate_kwargs,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        past_key_values=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
     ):
-        r"""
-        Returns:
-
-        Example:
-
-        ```python
-        >>> from PIL import Image
-        >>> import requests
-        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
-
-        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
-        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
-
-        >>> conditional_text = "A stop sign"
-        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
-
-        >>> inputs = processor(images=image, text=conditional_text, return_tensors="pt", add_special_tokens=True)
-
-        >>> # forward pass
-        >>> outputs = model.generate(**inputs)
-
-        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
-        ['A stop sign the street with a sign that says yes']
-        ```"""
-        batch_size, _, _ = flattened_patches.shape
-
-        vision_outputs = self.encoder(flattened_patches=flattened_patches, attention_mask=attention_mask)
-
-        image_embeds = vision_outputs[0]
-
-        if isinstance(decoder_input_ids, torch.Tensor):
-            # check if the first element of `input_ids` is equal to `decoder_input_ids`:
-            if (decoder_input_ids[:, 0] != self.decoder_start_token_id).all().item():
-                # add `decoder_input_ids` as first token to `input_ids`
-                decoder_input_ids = torch.cat(
+        if isinstance(input_ids, torch.Tensor):
+            # check whether the first token of `input_ids` is `decoder_start_token_id`
+            if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
+                # if not, prepend `decoder_start_token_id` to `input_ids`
+                input_ids = torch.cat(
                     [
-                        torch.ones((decoder_input_ids.shape[0], 1), dtype=torch.long, device=decoder_input_ids.device)
-                        * self.decoder_start_token_id,
-                        decoder_input_ids,
+                        torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
+                        * self.config.decoder_start_token_id,
+                        input_ids,
                     ],
                     dim=-1,
                 )
@@ -1823,20 +1801,26 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
                     ],
                     dim=-1,
                 )
 
-        elif decoder_input_ids is None:
-            decoder_input_ids = (
-                torch.LongTensor([[self.decoder_start_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
-            )
+        elif input_ids is None:
+            batch_size = flattened_patches.shape[0]
+            input_ids = torch.LongTensor([[self.config.decoder_start_token_id]]).repeat(batch_size, 1).to(flattened_patches.device)
 
         if decoder_attention_mask is None:
-            decoder_attention_mask = torch.ones_like(decoder_input_ids).to(image_embeds.device)
+            decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
 
-        outputs = self.decoder.generate(
-            input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
-            encoder_hidden_states=image_embeds,
-            encoder_attention_mask=attention_mask,
-            **generate_kwargs,
-        )
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
 
-        return outputs
+        return {
+            "flattened_patches": flattened_patches,
+            "decoder_input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,
+        }
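With the custom `generate()` deleted, decoding runs through the standard `GenerationMixin.generate()`, which calls the `prepare_inputs_for_generation` above at every step. A hedged sketch of the first decoding step follows; the dummy patch tensor shape (2048 patches, 770 features per patch) assumes the base checkpoint's defaults:

```python
# Hedged sketch (not part of the diff): what the generic generation loop
# receives from prepare_inputs_for_generation on the first decoding step.
import torch

from transformers import Pix2StructForConditionalGeneration

model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

step_inputs = model.prepare_inputs_for_generation(
    input_ids=torch.tensor([[42]]),  # does not start with decoder_start_token_id (0)
    flattened_patches=torch.zeros(1, 2048, 770),
)
print(step_inputs["decoder_input_ids"])  # tensor([[ 0, 42]]): start token prepended
print(step_inputs["decoder_attention_mask"])  # tensor([[1, 1]]): created on the fly
```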
--- a/tests/models/pix2struct/test_modeling_pix2struct.py
+++ b/tests/models/pix2struct/test_modeling_pix2struct.py
@@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
             # signature.parameters is an OrderedDict => so arg_names order is deterministic
             arg_names = [*signature.parameters.keys()]
 
-            if model.config.is_encoder_decoder:
-                expected_arg_names = [
-                    "input_ids",
-                    "attention_mask",
-                    "decoder_input_ids",
-                    "decoder_attention_mask",
-                ]
-                expected_arg_names.extend(
-                    ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
-                    if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
-                    else ["encoder_outputs"]
-                )
-                self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
-            else:
-                expected_arg_names = (
-                    ["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"]
-                )
-                self.assertListEqual(arg_names[:1], expected_arg_names)
+            expected_arg_names = [
+                "flattened_patches",
+                "attention_mask",
+                "decoder_input_ids",
+                "decoder_attention_mask",
+                "head_mask",
+                "decoder_head_mask",
+                "cross_attn_head_mask",
+                "encoder_outputs",
+                "past_key_values",
+                "labels",
+                "decoder_inputs_embeds",
+                "use_cache",
+            ]
+
+            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
 
     def test_training(self):
         if not self.model_tester.is_training:
@@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
         )
 
     def test_vqa_model(self):
-        model_id = "ybelkada/pix2struct-ai2d-base"
+        model_id = "google/pix2struct-ai2d-base"
         image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
         image = Image.open(requests.get(image_url, stream=True).raw)
@@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
         self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
 
     def test_vqa_model_batched(self):
-        model_id = "ybelkada/pix2struct-ai2d-base"
+        model_id = "google/pix2struct-ai2d-base"
         image_urls = [
             "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
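For reference, a hedged sketch of the end-to-end VQA flow these tests cover, using the renamed `google/pix2struct-ai2d-base` checkpoint; the exact question string here is an assumption patterned on the ai2d demo image:

```python
# Hedged sketch of the VQA flow exercised by test_vqa_model; the expected
# decoded answer in the test is "ash cloud".
import requests
from PIL import Image

from transformers import AutoProcessor, Pix2StructForConditionalGeneration

model_id = "google/pix2struct-ai2d-base"
processor = AutoProcessor.from_pretrained(model_id)
model = Pix2StructForConditionalGeneration.from_pretrained(model_id)

image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)

question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
inputs = processor(images=image, text=question, return_tensors="pt")

predictions = model.generate(**inputs)
print(processor.decode(predictions[0], skip_special_tokens=True))
```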
--- a/utils/documentation_tests.txt
+++ b/utils/documentation_tests.txt
@@ -306,6 +306,7 @@
 src/transformers/models/pegasus/tokenization_pegasus.py
 src/transformers/models/pegasus/tokenization_pegasus_fast.py
 src/transformers/models/perceiver/tokenization_perceiver.py
 src/transformers/models/phobert/tokenization_phobert.py
+src/transformers/models/pix2struct/modeling_pix2struct.py
 src/transformers/models/plbart/tokenization_plbart.py
 src/transformers/models/prophetnet/tokenization_prophetnet.py
 src/transformers/models/rag/tokenization_rag.py