Unverified Commit eeb9264a authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Support training with a local image folder (#152)

* Support training with an image folder

* style
parent b6447fa8
...@@ -22,7 +22,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset: ...@@ -22,7 +22,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash ```bash
accelerate launch train_unconditional.py \ accelerate launch train_unconditional.py \
--dataset="huggan/flowers-102-categories" \ --dataset_name="huggan/flowers-102-categories" \
--resolution=64 \ --resolution=64 \
--output_dir="ddpm-ema-flowers-64" \ --output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \ --train_batch_size=16 \
...@@ -46,7 +46,7 @@ The command to train a DDPM UNet model on the Pokemon dataset: ...@@ -46,7 +46,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:
```bash ```bash
accelerate launch train_unconditional.py \ accelerate launch train_unconditional.py \
--dataset="huggan/pokemon" \ --dataset_name="huggan/pokemon" \
--resolution=64 \ --resolution=64 \
--output_dir="ddpm-ema-pokemon-64" \ --output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \ --train_batch_size=16 \
...@@ -62,3 +62,68 @@ An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64 ...@@ -62,3 +62,68 @@ An example trained model: https://huggingface.co/anton-l/ddpm-ema-pokemon-64
A full training run takes 2 hours on 4xV100 GPUs. A full training run takes 2 hours on 4xV100 GPUs.
<img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" /> <img src="https://user-images.githubusercontent.com/26864830/180248200-928953b4-db38-48db-b0c6-8b740fe6786f.png" width="700" />
### Using your own data
To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
Below, we explain both in more detail.
#### Provide the dataset as a folder
If you provide your own folders with images, the script expects the following directory structure:
```bash
data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png
```
In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:
```bash
accelerate launch train_unconditional.py \
--train_data_dir <path-to-train-directory> \
<other-arguments>
```
Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects.
#### Upload your data to the hub, as a (possibly private) repo
It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following:
```python
from datasets import load_dataset
# example 1: local folder
dataset = load_dataset("imagefolder", data_dir="path_to_your_folder")
# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="path_to_zip_file")
# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd)
dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip")
# example 4: providing several splits
dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]})
```
`ImageFolder` will create an `image` column containing the PIL-encoded images.
Next, push it to the hub!
```python
# assuming you have ran the huggingface-cli login command in a terminal
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
dataset.push_to_hub("name_of_your_dataset", private=True)
```
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
...@@ -75,7 +75,17 @@ def main(args): ...@@ -75,7 +75,17 @@ def main(args):
Normalize([0.5], [0.5]), Normalize([0.5], [0.5]),
] ]
) )
dataset = load_dataset(args.dataset, split="train")
if args.dataset_name is not None:
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
use_auth_token=True if args.use_auth_token else None,
split="train",
)
else:
dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
def transforms(examples): def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]] images = [augmentations(image.convert("RGB")) for image in examples["image"]]
...@@ -179,9 +189,12 @@ def main(args): ...@@ -179,9 +189,12 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1) parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64") parser.add_argument("--dataset_config_name", type=str, default=None)
parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
parser.add_argument("--overwrite_output_dir", action="store_true") parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--resolution", type=int, default=64) parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16) parser.add_argument("--eval_batch_size", type=int, default=16)
...@@ -201,6 +214,7 @@ if __name__ == "__main__": ...@@ -201,6 +214,7 @@ if __name__ == "__main__":
parser.add_argument("--ema_power", type=float, default=3 / 4) parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999) parser.add_argument("--ema_max_decay", type=float, default=0.9999)
parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--use_auth_token", action="store_true")
parser.add_argument("--hub_token", type=str, default=None) parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None) parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true") parser.add_argument("--hub_private_repo", action="store_true")
...@@ -222,4 +236,7 @@ if __name__ == "__main__": ...@@ -222,4 +236,7 @@ if __name__ == "__main__":
if env_local_rank != -1 and env_local_rank != args.local_rank: if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank args.local_rank = env_local_rank
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
main(args) main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment