<!--Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Vision Encoder Decoder Models

## Overview

The [`VisionEncoderDecoderModel`] can be used to initialize an image-to-text model with any
pretrained Transformer-based vision model as the encoder (*e.g.* [ViT](vit), [BEiT](beit), [DeiT](deit), [Swin](swin))
and any pretrained language model as the decoder (*e.g.* [RoBERTa](roberta), [GPT2](gpt2), [BERT](bert), [DistilBERT](distilbert)).

The effectiveness of initializing image-to-text-sequence models with pretrained checkpoints has been shown in (for
example) [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang,
Zhoujun Li, Furu Wei.

After such a [`VisionEncoderDecoderModel`] has been trained/fine-tuned, it can be saved/loaded just like any other model (see the examples below
for more information).
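
For example, saving and reloading a fine-tuned captioning checkpoint could look as follows (a minimal sketch; the local directory name is only illustrative):

```python
>>> from transformers import VisionEncoderDecoderModel

>>> # load a fine-tuned checkpoint, save it locally and reload it from disk
>>> model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> model.save_pretrained("./my-image-captioning-model")
>>> model = VisionEncoderDecoderModel.from_pretrained("./my-image-captioning-model")
```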

An example application is image captioning, in which the encoder is used to encode the image, after which an autoregressive language model generates
the caption. Another example is optical character recognition. Refer to [TrOCR](trocr), which is an instance of [`VisionEncoderDecoderModel`].
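
As a sketch of the OCR use case, one can run a TrOCR checkpoint such as `microsoft/trocr-base-handwritten` together with its processor (the image path below is a placeholder):

```python
>>> from PIL import Image
>>> from transformers import TrOCRProcessor, VisionEncoderDecoderModel

>>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
>>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

>>> # placeholder path to an image containing a single line of handwritten text
>>> image = Image.open("path/to/handwriting.png").convert("RGB")
>>> pixel_values = processor(image, return_tensors="pt").pixel_values

>>> generated_ids = model.generate(pixel_values)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```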

## Randomly initializing `VisionEncoderDecoderModel` from model configurations.

[`VisionEncoderDecoderModel`] can be randomly initialized from an encoder and a decoder config. In the following example, we show how to do this using the default [`ViTModel`] configuration for the encoder
and the default [`BertForCausalLM`] configuration for the decoder.

```python
>>> from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

>>> config_encoder = ViTConfig()
>>> config_decoder = BertConfig()

>>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = VisionEncoderDecoderModel(config=config)
```

## Initializing `VisionEncoderDecoderModel` from a pretrained encoder and a pretrained decoder.

[`VisionEncoderDecoderModel`] can be initialized from a pretrained encoder checkpoint and a pretrained decoder checkpoint. Note that any pretrained Transformer-based vision model, *e.g.* [Swin](swin), can serve as the encoder, while the decoder can be a pretrained auto-encoding model (*e.g.* BERT), a pretrained causal language model (*e.g.* GPT2), or the pretrained decoder part of a sequence-to-sequence model (*e.g.* the decoder of BART).
Depending on which architecture you choose as the decoder, the cross-attention layers might be randomly initialized.
Initializing [`VisionEncoderDecoderModel`] from a pretrained encoder and decoder checkpoint requires the model to be fine-tuned on a downstream task, as has been shown in [the *Warm-starting-encoder-decoder blog post*](https://huggingface.co/blog/warm-starting-encoder-decoder).
To do so, the `VisionEncoderDecoderModel` class provides a [`VisionEncoderDecoderModel.from_encoder_decoder_pretrained`] method.

```python
>>> from transformers import VisionEncoderDecoderModel

>>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
...     "microsoft/swin-base-patch4-window7-224-in22k", "bert-base-uncased"
... )
```

## Loading an existing `VisionEncoderDecoderModel` checkpoint and performing inference.

To load fine-tuned checkpoints of the `VisionEncoderDecoderModel` class, [`VisionEncoderDecoderModel`] provides the `from_pretrained(...)` method just like any other model architecture in Transformers.

To perform inference, one uses the [`generate`] method, which autoregressively generates text. This method supports various forms of decoding, such as greedy decoding, beam search and multinomial sampling.

```python
>>> import requests
>>> from PIL import Image

>>> from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

>>> # load a fine-tuned image captioning model and corresponding tokenizer and image processor
>>> model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

>>> # let's perform inference on an image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> pixel_values = image_processor(image, return_tensors="pt").pixel_values

>>> # autoregressively generate caption (uses greedy decoding by default)
>>> generated_ids = model.generate(pixel_values)
>>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
a cat laying on a blanket next to a cat laying on a bed
```
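
The decoding strategy is selected via the keyword arguments of [`generate`]. A short sketch, reusing the `model` and `pixel_values` from above (the parameter values are only illustrative):

```python
>>> # beam search with 4 beams
>>> generated_ids = model.generate(pixel_values, num_beams=4, max_length=16)

>>> # multinomial sampling
>>> generated_ids = model.generate(pixel_values, do_sample=True, top_k=50, max_length=16)
```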

## Loading a PyTorch checkpoint into `TFVisionEncoderDecoderModel`.

[`TFVisionEncoderDecoderModel.from_pretrained`] currently doesn't support initializing the model from a
PyTorch checkpoint. Passing `from_pt=True` to this method will throw an exception. If there are only PyTorch
checkpoints for a particular vision encoder-decoder model, a workaround is:

```python
>>> from transformers import VisionEncoderDecoderModel, TFVisionEncoderDecoderModel

>>> _model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")

>>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
...     "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
```
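
To avoid repeating the conversion, the resulting TensorFlow model can then be saved and reloaded directly (a sketch; the directory name is arbitrary):

```python
>>> model.save_pretrained("./tf-vit-gpt2-image-captioning")
>>> model = TFVisionEncoderDecoderModel.from_pretrained("./tf-vit-gpt2-image-captioning")
```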

## Training

Once the model is created, it can be fine-tuned similarly to BART, T5 or any other encoder-decoder model on a dataset of (image, text) pairs.
As the example below shows, only two inputs are required for the model to compute a loss: `pixel_values` (the
images) and `labels` (the `input_ids` of the encoded target sequence).

```python
>>> from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel
>>> from datasets import load_dataset

>>> image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
>>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
...     "google/vit-base-patch16-224-in21k", "bert-base-uncased"
... )

>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> pixel_values = image_processor(image, return_tensors="pt").pixel_values

>>> labels = tokenizer(
...     "an image of two cats chilling on a couch",
...     return_tensors="pt",
... ).input_ids

>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(pixel_values=pixel_values, labels=labels).loss
```
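
From here, a standard PyTorch training loop applies. A minimal sketch continuing the example above (the optimizer and learning rate are only illustrative; a real setup would iterate over a `DataLoader` of batched examples):

```python
>>> import torch

>>> optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

>>> # one optimization step on the single example above
>>> loss = model(pixel_values=pixel_values, labels=labels).loss
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()
```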

This model was contributed by [nielsr](https://github.com/nielsrogge). This model's TensorFlow and Flax versions
were contributed by [ydshieh](https://github.com/ydshieh).

## VisionEncoderDecoderConfig

[[autodoc]] VisionEncoderDecoderConfig

<frameworkcontent>
<pt>

## VisionEncoderDecoderModel

[[autodoc]] VisionEncoderDecoderModel
    - forward
    - from_encoder_decoder_pretrained

</pt>
<tf>

## TFVisionEncoderDecoderModel

[[autodoc]] TFVisionEncoderDecoderModel
    - call
    - from_encoder_decoder_pretrained

</tf>
<jax>

## FlaxVisionEncoderDecoderModel

[[autodoc]] FlaxVisionEncoderDecoderModel
    - __call__
    - from_encoder_decoder_pretrained

</jax>
</frameworkcontent>