<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Image classification

<Youtube id="tjAIM7BOYhw"/>

Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that represent an image. There are many uses for image classification, like detecting damage after a disaster, monitoring crop health, or helping screen medical images for signs of disease.

This guide will show you how to fine-tune [ViT](https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/vit) on the [Food-101](https://huggingface.co/datasets/food101) dataset to classify a food item in an image.

<Tip>

See the image classification [task page](https://huggingface.co/tasks/image-classification) for more information about its associated models, datasets, and metrics.

</Tip>

## Load Food-101 dataset

The full Food-101 dataset is large, so load only the first 5000 images of its training split with the 🤗 Datasets library:

```py
>>> from datasets import load_dataset

>>> food = load_dataset("food101", split="train[:5000]")
```

Split this dataset into a train and test set:

```py
>>> food = food.train_test_split(test_size=0.2)
```
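To make the resulting split sizes concrete, here is a minimal pure-Python sketch of a shuffled 80/20 split over 5000 example indices. The helper `split_indices` and its `seed` parameter are hypothetical and only illustrate the idea; this is not how 🤗 Datasets implements `train_test_split`.

```python
import random


def split_indices(num_examples, test_size=0.2, seed=0):
    """Shuffle the indices 0..num_examples-1 and split them into train and test lists."""
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)
    num_test = round(num_examples * test_size)
    return indices[num_test:], indices[:num_test]


train_idx, test_idx = split_indices(5000, test_size=0.2)
print(len(train_idx), len(test_idx))  # 4000 1000
```

With `test_size=0.2`, the 5000 loaded images end up as 4000 training examples and 1000 test examples.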

Then take a look at an example:

```py
>>> food["train"][0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
 'label': 79}
```

The `image` field contains a PIL image, and each `label` is an integer that represents a class. Create two dictionaries: one that maps each label name to its integer ID, and one that maps each ID back to its label name. These mappings let the model report a label name instead of a bare label number:

```py
>>> labels = food["train"].features["label"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
...     label2id[label] = str(i)
...     id2label[str(i)] = label
```

Now you can convert a label number to its label name:

```py
>>> id2label[str(79)]
'prime_rib'
```

Each food class - or label - corresponds to a number; `79` indicates a prime rib in the example above.
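The mappings above store IDs as strings, so a lookup needs `str(...)` around an integer class index. The sketch below shows the round trip on a three-class list; the short `labels` list is an illustrative stand-in for the full list of class names, and `predicted_class` is a hypothetical integer index.

```python
# Illustrative stand-in for food["train"].features["label"].names
labels = ["apple_pie", "baby_back_ribs", "baklava"]

label2id, id2label = {}, {}
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

predicted_class = 2  # hypothetical integer class index
name = id2label[str(predicted_class)]  # keys are strings, so convert first
print(name)            # baklava
print(label2id[name])  # 2 (stored as the string "2")

# An integer key is not present, because every key was stored as a string.
print(predicted_class in id2label)  # False
```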

## Preprocess

Load the ViT feature extractor to process the image into a tensor:

```py
>>> from transformers import AutoFeatureExtractor

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
```

Apply several image transformations to the dataset to make the model more robust against overfitting. Here you'll use torchvision's [`transforms`](https://pytorch.org/vision/stable/transforms.html) module. Crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:

```py
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
>>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
```
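As a rough illustration of what `ToTensor` followed by `Normalize` does to a single channel value, the sketch below works through the arithmetic in plain Python. It assumes a mean and standard deviation of 0.5, which is an assumption here; check `feature_extractor.image_mean` and `feature_extractor.image_std` for the values your checkpoint actually uses. The helper names are hypothetical.

```python
def to_unit_range(pixel):
    """Scale an 8-bit pixel value (0-255) to the range [0, 1], as ToTensor does."""
    return pixel / 255.0


def normalize_value(value, mean, std):
    """Apply per-channel normalization: (value - mean) / std."""
    return (value - mean) / std


mean, std = 0.5, 0.5  # assumed values, not read from the feature extractor

darkest = normalize_value(to_unit_range(0), mean, std)
brightest = normalize_value(to_unit_range(255), mean, std)
print(darkest, brightest)  # -1.0 1.0
```

Under that assumption, pixel values end up in the range [-1, 1] before they reach the model.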

Create a preprocessing function that will apply the transforms and return the `pixel_values` - the inputs to the model - of the image:

```py
>>> def transforms(examples):
...     examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
...     del examples["image"]
...     return examples
```

Use 馃 Dataset's [`~datasets.Dataset.with_transform`] method to apply the transforms over the entire dataset. The transforms are applied on-the-fly when you load an element of the dataset:

```py
>>> food = food.with_transform(transforms)
```
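The toy class below sketches what "on-the-fly" means: the stored rows stay untouched, and the transform only runs when an element is read. `LazyDataset` and `toy_transforms` are hypothetical names, and the class is a simplified stand-in, not the 🤗 Datasets implementation.

```python
class LazyDataset:
    """Toy dataset that applies a batch transform only when an element is accessed."""

    def __init__(self, rows, transform):
        self.rows = rows
        self.transform = transform

    def __getitem__(self, index):
        # Wrap the row as a batch of size 1, because the transform expects lists.
        batch = {key: [value] for key, value in self.rows[index].items()}
        batch = self.transform(batch)
        # Unwrap the batch back into a single example.
        return {key: values[0] for key, values in batch.items()}


calls = []  # records each time the transform runs


def toy_transforms(examples):
    # Same shape as the guide's function: build `pixel_values`, then drop `image`.
    calls.append(len(examples["image"]))
    examples["pixel_values"] = [[p / 255.0 for p in img] for img in examples["image"]]
    del examples["image"]
    return examples


rows = [{"image": [0, 255], "label": 79}, {"image": [51, 102], "label": 3}]
dataset = LazyDataset(rows, toy_transforms)

print(len(calls))  # 0: nothing is transformed until an element is read
example = dataset[0]
print(example)     # {'label': 79, 'pixel_values': [0.0, 1.0]}
print(len(calls))  # 1
```

Because the transform runs on every access, a random transform such as `RandomResizedCrop` can produce a different crop each time the same image is read.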

Use [`DefaultDataCollator`] to create a batch of examples. Unlike other data collators in 🤗 Transformers, the `DefaultDataCollator` does not apply additional preprocessing such as padding.

```py
>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator()
```
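Conceptually, collating without extra preprocessing means gathering each field of the individual examples into one batch-level entry. The sketch below uses nested lists in place of tensors, and `toy_collate` is a hypothetical helper that approximates this behavior rather than reproducing the library code.

```python
def toy_collate(examples):
    """Group a list of example dicts into a single batch dict, with no padding."""
    batch = {}
    # The per-example `label` field is gathered under the batch-level key `labels`.
    batch["labels"] = [example["label"] for example in examples]
    for key in examples[0]:
        if key != "label":
            batch[key] = [example[key] for example in examples]
    return batch


examples = [
    {"pixel_values": [[0.0, 1.0], [0.5, 0.5]], "label": 79},
    {"pixel_values": [[1.0, 0.0], [0.25, 0.75]], "label": 3},
]
batch = toy_collate(examples)
print(batch["labels"])             # [79, 3]
print(len(batch["pixel_values"]))  # 2
```

No padding is needed here because the transforms already give every image the same size.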

## Train

<frameworkcontent>
<pt>
Load ViT with [`AutoModelForImageClassification`]. Specify the number of labels, and pass the model the mapping between label number and label class:

```py
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

>>> model = AutoModelForImageClassification.from_pretrained(
...     "google/vit-base-patch16-224-in21k",
...     num_labels=len(labels),
...     id2label=id2label,
...     label2id=label2id,
... )
```

<Tip>

If you aren't familiar with fine-tuning a model with the [`Trainer`], take a look at the basic tutorial [here](../training#finetune-with-trainer)!

</Tip>

At this point, only three steps remain:

1. Define your training hyperparameters in [`TrainingArguments`]. By default, unused columns are removed, which would drop the `image` column. Without the `image` column, you can't create `pixel_values`, so set `remove_unused_columns=False` to prevent this behavior!
2. Pass the training arguments to [`Trainer`] along with the model, datasets, feature extractor (through the `tokenizer` argument), and data collator.
3. Call [`~Trainer.train`] to fine-tune your model.

```py
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     per_device_train_batch_size=16,
...     evaluation_strategy="steps",
...     num_train_epochs=4,
...     fp16=True,
...     save_steps=100,
...     eval_steps=100,
...     logging_steps=10,
...     learning_rate=2e-4,
...     save_total_limit=2,
...     remove_unused_columns=False,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=food["train"],
...     eval_dataset=food["test"],
...     tokenizer=feature_extractor,
... )

>>> trainer.train()
```
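To get a feel for how long this run is, the sketch below estimates the number of optimization steps implied by the settings above. It assumes a single device, no gradient accumulation, and the 4000 training examples that result from the 80/20 split of 5000 images; these assumptions are not checked against an actual run.

```python
import math

train_examples = 4000  # 80% of the 5000 loaded images
batch_size = 16        # per_device_train_batch_size, assuming one device
num_epochs = 4         # num_train_epochs
eval_every = 100       # eval_steps (save_steps uses the same interval)

steps_per_epoch = math.ceil(train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
num_evaluations = total_steps // eval_every

print(steps_per_epoch)  # 250
print(total_steps)      # 1000
print(num_evaluations)  # 10
```

Under these assumptions, a checkpoint is also written every 100 steps, and `save_total_limit=2` keeps only two of them on disk at a time.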
</pt>
</frameworkcontent>

<Tip>

For a more in-depth example of how to fine-tune a model for image classification, take a look at the corresponding [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).

</Tip>