Unverified Commit 048443db authored by NielsRogge, committed by GitHub

Improve image classification example (#16585)



* Improve README

* Make dataset_name argument optional

* Improve local data

* Fix bug

* Improve README some more

* Apply suggestions from code review

* Improve README
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 3e4eec47
@@ -14,13 +14,18 @@ See the License for the specific language governing permissions and
limitations under the License.
-->

# Image classification example

This directory contains a script, `run_image_classification.py`, that showcases how to fine-tune any model supported by the [`AutoModelForImageClassification` API](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForImageClassification) (such as [ViT](https://huggingface.co/docs/transformers/main/en/model_doc/vit), [ConvNeXT](https://huggingface.co/docs/transformers/main/en/model_doc/convnext), [ResNet](https://huggingface.co/docs/transformers/main/en/model_doc/resnet), [Swin Transformer](https://huggingface.co/docs/transformers/main/en/model_doc/swin)...) using PyTorch. It can be used to fine-tune models both on well-known datasets (like [CIFAR-10](https://huggingface.co/datasets/cifar10), [Fashion MNIST](https://huggingface.co/datasets/fashion_mnist), ...) and on your own custom data.

This page includes 2 sections:
- [Using datasets from the hub](#using-datasets-from-🤗-hub)
- [Using your own data](#using-your-own-data)

## Using datasets from 🤗 `Hub`

Here we show how to fine-tune a Vision Transformer (`ViT`) on the [beans](https://huggingface.co/datasets/beans) dataset, to classify the disease type of bean leaves.

👀 See the results here: [nateraw/vit-base-beans](https://huggingface.co/nateraw/vit-base-beans).

@@ -46,36 +51,21 @@ python run_image_classification.py \
    --seed 1337
```
To fine-tune another model, simply provide the `--model_name_or_path` argument. To train on another dataset, simply set the `--dataset_name` argument.

👀 See the results here: [nateraw/vit-base-cats-vs-dogs](https://huggingface.co/nateraw/vit-base-cats-vs-dogs).
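A sketch of such a command, swapping in a different checkpoint via `--model_name_or_path` and a different dataset via `--dataset_name`, could look like the following. The checkpoint, dataset and hyperparameter values below are illustrative assumptions for this sketch, not values taken from this example:

```bash
# Hedged sketch: checkpoint, dataset and hyperparameters are illustrative, not tested defaults.
python run_image_classification.py \
    --model_name_or_path microsoft/swin-tiny-patch4-window7-224 \
    --dataset_name cifar10 \
    --output_dir ./cifar10_outputs/ \
    --remove_unused_columns False \
    --do_train \
    --do_eval \
    --learning_rate 2e-5 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --seed 1337
```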
## Using your own data

To use your own dataset, there are 2 ways:
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments,
- you can upload your dataset to the hub (possibly as a private repo, if you prefer), and simply pass the `--dataset_name` argument.

Below, we explain both in more detail.

### Provide them as folders

If you provide your own folders with images, the script expects the following directory structure:

```bash
root/dog/xxx.png
@@ -87,11 +77,10 @@ root/cat/nsdf3.png
root/cat/[...]/asd932_.png
```

In other words, you need to organize your images in subfolders, based on their class. You can then run the script like this:

```bash
python run_image_classification.py \
    --train_dir <path-to-train-root> \
    --output_dir ./outputs/ \
    --remove_unused_columns False \
@@ -99,12 +88,48 @@ python run_image_classification.py \
    --do_eval
```

Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature, which will automatically turn the folders into 🤗 Dataset objects.

#### 💡 The above will split the train dir into training and evaluation sets
- To control the split amount, use the `--train_val_split` flag (see the sketch below).
- To provide your own validation split in its own directory, you can pass the `--validation_dir <path-to-val-root>` flag.
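As an illustration of the first bullet, a minimal sketch of holding out part of the training folder for evaluation could look like this; the path and the split value below are placeholders, not defaults of the script:

```bash
# Hedged sketch: hold out 15% of --train_dir as a validation set.
# The path and the split value are placeholders.
python run_image_classification.py \
    --train_dir <path-to-train-root> \
    --train_val_split 0.15 \
    --output_dir ./outputs/ \
    --remove_unused_columns False \
    --do_train \
    --do_eval
```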
### Upload your data to the hub, as a (possibly private) repo
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
```python
from datasets import load_dataset
# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
# example 4: providing several splits
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
```
`ImageFolder` will create a `label` column, and the label name is based on the directory name.
Next, push it to the hub!
```python
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```
And that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub (as explained in [Using datasets from the hub](#using-datasets-from-🤗-hub)).
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
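As a sketch, such a run could look like the following. The dataset name below is a placeholder for your own hub repository, and `--use_auth_token` should only be needed if the repo is private:

```bash
# Hedged sketch: replace the dataset name with your own hub repository.
# --use_auth_token is only needed for private datasets.
python run_image_classification.py \
    --dataset_name <your-username>/name_of_your_dataset \
    --use_auth_token True \
    --output_dir ./outputs/ \
    --remove_unused_columns False \
    --do_train \
    --do_eval
```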
# Sharing your model on 🤗 Hub
0. If you haven't already, [sign up](https://huggingface.co/join) for a 🤗 account

@@ -116,13 +141,21 @@ $ git config --global user.email "you@example.com"
$ git config --global user.name "Your Name"
```

2. Log in with your HuggingFace account credentials using `huggingface-cli`:

```bash
$ huggingface-cli login
# ...follow the prompts
```
or, in case you're running in a notebook:
```python
from huggingface_hub import notebook_login
notebook_login()
```
3. When running the script, pass the following arguments:

```bash
...
```

The second part of the commit updates `run_image_classification.py`:
@@ -72,13 +72,15 @@ def pil_loader(path: str):
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
    them on the command line.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Name of a dataset from the hub (could be your own, possibly private dataset hosted on the hub)."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
@@ -104,12 +106,10 @@ class DataTrainingArguments:
    )

    def __post_init__(self):
        if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
            raise ValueError(
                "You must specify either a dataset name from the hub or a train and/or validation directory."
            )


@dataclass
@@ -201,25 +201,37 @@ def main():
    )

    # Initialize our dataset and prepare it for the 'image-classification' task.
    if data_args.dataset_name is not None:
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            task="image-classification",
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_dir is not None:
            data_files["train"] = os.path.join(data_args.train_dir, "**")
        if data_args.validation_dir is not None:
            data_files["validation"] = os.path.join(data_args.validation_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            task="image-classification",
        )

    # If we don't have a validation split, split off a percentage of train as validation.
    data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(data_args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Prepare label mappings.
    # We'll include these in the model's config to get human readable labels in the Inference API.
    labels = dataset["train"].features["labels"].names
    label2id, id2label = dict(), dict()
    for i, label in enumerate(labels):
        label2id[label] = str(i)
@@ -291,29 +303,31 @@ def main():
        return example_batch

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        if data_args.max_train_samples is not None:
            dataset["train"] = (
                dataset["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
            )
        # Set the training transforms
        dataset["train"].set_transform(train_transforms)

    if training_args.do_eval:
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        if data_args.max_eval_samples is not None:
            dataset["validation"] = (
                dataset["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
            )
        # Set the validation transforms
        dataset["validation"].set_transform(val_transforms)

    # Initialize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"] if training_args.do_train else None,
        eval_dataset=dataset["validation"] if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=feature_extractor,
        data_collator=collate_fn,
@@ -343,7 +357,7 @@ def main():
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "image-classification",
        "dataset": data_args.dataset_name,
        "tags": ["image-classification", "vision"],
    }

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)