"vscode:/vscode.git/clone" did not exist on "50931d69465c2a7c255cb321ee87283403189565"
Unverified Commit d27e996c authored by Isamu Isozaki's avatar Isamu Isozaki Committed by GitHub
Browse files

Adding VQGAN Training script (#5483)



* Init commit

* Removed einops

* Added default movq config for training

* Update explanation of prompts

* Fixed inheritance of discriminator and init_tracker

* Fixed incompatible api between muse and here

* Fixed output

* Setup init training

* Basic structure done

* Removed attention for quick tests

* Style fixes

* Fixed vae/vqgan styles

* Removed redefinition of wandb

* Fixed log_validation and tqdm

* Nothing commit

* Added commit loss to lookup_from_codebook

* Update src/diffusers/models/vq_model.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Adding perliminary README

* Fixed one typo

* Local changes

* Fixed main issues

* Merging

* Update src/diffusers/models/vq_model.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Testing+Fixed bugs in training script

* Some style fixes

* Added wandb to docs

* Fixed timm test

* get testing suite ready.

* remove return loss

* remove return_loss

* Remove diffs

* Remove diffs

* fix ruff format

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 72780ff5
...@@ -156,7 +156,7 @@ jobs: ...@@ -156,7 +156,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_examples' }} if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: | run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install peft python -m uv pip install peft timm
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \ --make-reports=tests_${{ matrix.config.report }} \
examples examples
......
...@@ -426,6 +426,7 @@ jobs: ...@@ -426,6 +426,7 @@ jobs:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
run: | run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/ python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports - name: Failure short reports
......
...@@ -107,7 +107,7 @@ jobs: ...@@ -107,7 +107,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_examples' }} if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: | run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install peft python -m uv pip install peft timm
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \ --make-reports=tests_${{ matrix.config.report }} \
examples examples
......
## Training an VQGAN VAE
VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and was combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is it's a type of a variational auto encoder with tokens as the latent space similar to tokens for LLMs. This script was adapted from a [pr to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52) with general code following [lucidrian's implementation of the vqgan training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py) but both of these implementation follow from the [taming transformer repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).
Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
### Training on CIFAR10
The command to train a VQGAN model on cifar10 dataset:
```bash
accelerate launch train_vqgan.py \
--dataset_name=cifar10 \
--image_column=img \
--validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg \
--resolution=128 \
--train_batch_size=2 \
--gradient_accumulation_steps=8 \
--report_to=wandb
```
An example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a lower scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocaburary size which at most can be around 16384. How to do this is shown below.
# Modifying the architecture
To modify the architecture of the vqgan model you can save the config taken from [here](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder/blob/main/movq/config.json) and then provide that to the script with the option --model_config_name_or_path. This config is below
```
{
"_class_name": "VQModel",
"_diffusers_version": "0.17.0.dev0",
"act_fn": "silu",
"block_out_channels": [
128,
256,
256,
512
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"AttnDownEncoderBlock2D"
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"norm_type": "spatial",
"num_vq_embeddings": 16384,
"out_channels": 3,
"sample_size": 32,
"scaling_factor": 0.18215,
"up_block_types": [
"AttnUpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
],
"vq_embed_dim": 4
}
```
To lower the amount of layers in a VQGan, you can remove layers by modifying the block_out_channels, down_block_types, and up_block_types like below
```
{
"_class_name": "VQModel",
"_diffusers_version": "0.17.0.dev0",
"act_fn": "silu",
"block_out_channels": [
128,
256,
256,
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"norm_type": "spatial",
"num_vq_embeddings": 16384,
"out_channels": 3,
"sample_size": 32,
"scaling_factor": 0.18215,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
],
"vq_embed_dim": 4
}
```
For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that.
## Extra training tips/ideas
During logging take care to make sure data_time is low. data_time is the amount spent loading the data and where the GPU is not active. So essentially, it's the time wasted. The easiest way to lower data time is to increase the --dataloader_num_workers to a higher number like 4. Due to a bug in Pytorch, this only works on linux based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646)
Secondly, training should seem to be done when both the discriminator and the generator loss converges.
Thirdly, another low hanging fruit is just using ema using the --use_ema parameter. This tends to make the output images smoother. This has a con where you have to lower your batch size by 1 but it may be worth it.
Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, becareful with the feature map norms since this can easily overdominate the loss.
\ No newline at end of file
"""
Ported from Paella
"""
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
# Discriminator model ported from Paella https://github.com/dome272/Paella/blob/main/src_distributed/vqgan.py
class Discriminator(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, in_channels=3, cond_channels=0, hidden_channels=512, depth=6):
super().__init__()
d = max(depth - 3, 3)
layers = [
nn.utils.spectral_norm(
nn.Conv2d(in_channels, hidden_channels // (2**d), kernel_size=3, stride=2, padding=1)
),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = hidden_channels // (2 ** max((d - i), 0))
c_out = hidden_channels // (2 ** max((d - 1 - i), 0))
layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
self.shuffle = nn.Conv2d(
(hidden_channels + cond_channels) if cond_channels > 0 else hidden_channels, 1, kernel_size=1
)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):
x = self.encoder(x)
if cond is not None:
cond = cond.view(
cond.size(0),
cond.size(1),
1,
1,
).expand(-1, -1, x.size(-2), x.size(-1))
x = torch.cat([x, cond], dim=1)
x = self.shuffle(x)
x = self.logits(x)
return x
accelerate>=0.16.0
torchvision
transformers>=4.25.1
datasets
timm
numpy
tqdm
tensorboard
\ No newline at end of file
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import shutil
import sys
import tempfile
import torch
from diffusers import VQModel
from diffusers.utils.testing_utils import require_timm
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@require_timm
class TextToImage(ExamplesTestsAccelerate):
@property
def test_vqmodel_config(self):
return {
"_class_name": "VQModel",
"_diffusers_version": "0.17.0.dev0",
"act_fn": "silu",
"block_out_channels": [
32,
],
"down_block_types": [
"DownEncoderBlock2D",
],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"norm_type": "spatial",
"num_vq_embeddings": 32,
"out_channels": 3,
"sample_size": 32,
"scaling_factor": 0.18215,
"up_block_types": [
"UpDecoderBlock2D",
],
"vq_embed_dim": 4,
}
@property
def test_discriminator_config(self):
return {
"_class_name": "Discriminator",
"_diffusers_version": "0.27.0.dev0",
"in_channels": 3,
"cond_channels": 0,
"hidden_channels": 8,
"depth": 4,
}
def get_vq_and_discriminator_configs(self, tmpdir):
vqmodel_config_path = os.path.join(tmpdir, "vqmodel.json")
discriminator_config_path = os.path.join(tmpdir, "discriminator.json")
with open(vqmodel_config_path, "w") as fp:
json.dump(self.test_vqmodel_config, fp)
with open(discriminator_config_path, "w") as fp:
json.dump(self.test_discriminator_config, fp)
return vqmodel_config_path, discriminator_config_path
def test_vqmodel(self):
with tempfile.TemporaryDirectory() as tmpdir:
vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
test_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(
os.path.isfile(os.path.join(tmpdir, "discriminator", "diffusion_pytorch_model.safetensors"))
)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "vqmodel", "diffusion_pytorch_model.safetensors")))
def test_vqmodel_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
# Run training script with checkpointing
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# check can run an intermediate checkpoint
model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4"},
)
# Run training script for 2 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
# In the current script, checkpointing_steps 1 is equivalent to checkpointing_steps 2 as after the generator gets trained for one step,
# the discriminator gets trained and loss and saving happens after that. Thus we do not expect to get a checkpoint-5
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_vqmodel_checkpointing_use_ema(self):
with tempfile.TemporaryDirectory() as tmpdir:
vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
# Run training script with checkpointing
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=2
--output_dir {tmpdir}
--use_ema
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# check can run an intermediate checkpoint
model = VQModel.from_pretrained(tmpdir, subfolder="checkpoint-2/vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 2 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--output_dir {tmpdir}
--use_ema
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_vqmodel_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
# Run training script with checkpointing
# max_train_steps == 6, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# check checkpoint directories exist
# checkpoint-2 should have been deleted
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
vqmodel_config_path, discriminator_config_path = self.get_vq_and_discriminator_configs(tmpdir)
# Run training script with checkpointing
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# resume and we should try to checkpoint at 6, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
examples/vqgan/train_vqgan.py
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 32
--image_column image
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 8
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
--checkpoints_total_limit=2
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
model = VQModel.from_pretrained(tmpdir, subfolder="vqmodel")
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
_ = model(image)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8"},
)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
import shutil
import time
from pathlib import Path
import accelerate
import numpy as np
import PIL
import PIL.Image
import timm
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
from datasets import load_dataset
from discriminator import Discriminator
from huggingface_hub import create_repo
from packaging import version
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision import transforms
from tqdm import tqdm
from diffusers import VQModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
logger = get_logger(__name__, log_level="INFO")
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def _map_layer_to_idx(backbone, layers, offset=0):
"""Maps set of layer names to indices of model. Ported from anomalib
Returns:
Feature map extracted from the CNN
"""
idx = []
features = timm.create_model(
backbone,
pretrained=False,
features_only=False,
exportable=True,
)
for i in layers:
try:
idx.append(list(dict(features.named_children()).keys()).index(i) - offset)
except ValueError:
raise ValueError(
f"Layer {i} not found in model {backbone}. Select layer from {list(dict(features.named_children()).keys())}. The network architecture is {features}"
)
return idx
def get_perceptual_loss(pixel_values, fmap, timm_model, timm_model_resolution, timm_model_normalization):
img_timm_model_input = timm_model_normalization(F.interpolate(pixel_values, timm_model_resolution))
fmap_timm_model_input = timm_model_normalization(F.interpolate(fmap, timm_model_resolution))
if pixel_values.shape[1] == 1:
# handle grayscale for timm_model
img_timm_model_input, fmap_timm_model_input = (
t.repeat(1, 3, 1, 1) for t in (img_timm_model_input, fmap_timm_model_input)
)
img_timm_model_feats = timm_model(img_timm_model_input)
recon_timm_model_feats = timm_model(fmap_timm_model_input)
perceptual_loss = F.mse_loss(img_timm_model_feats[0], recon_timm_model_feats[0])
for i in range(1, len(img_timm_model_feats)):
perceptual_loss += F.mse_loss(img_timm_model_feats[i], recon_timm_model_feats[i])
perceptual_loss /= len(img_timm_model_feats)
return perceptual_loss
def grad_layer_wrt_loss(loss, layer):
return torch.autograd.grad(
outputs=loss,
inputs=layer,
grad_outputs=torch.ones_like(loss),
retain_graph=True,
)[0].detach()
def gradient_penalty(images, output, weight=10):
gradients = torch.autograd.grad(
outputs=output,
inputs=images,
grad_outputs=torch.ones(output.size(), device=images.device),
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
bsz = gradients.shape[0]
gradients = torch.reshape(gradients, (bsz, -1))
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
@torch.no_grad()
def log_validation(model, args, validation_transform, accelerator, global_step):
logger.info("Generating images...")
dtype = torch.float32
if accelerator.mixed_precision == "fp16":
dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
dtype = torch.bfloat16
original_images = []
for image_path in args.validation_images:
image = PIL.Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
image = validation_transform(image).to(accelerator.device, dtype=dtype)
original_images.append(image[None])
# Generate images
model.eval()
images = []
for original_image in original_images:
image = accelerator.unwrap_model(model)(original_image).sample
images.append(image)
model.train()
original_images = torch.cat(original_images, dim=0)
images = torch.cat(images, dim=0)
# Convert to PIL images
images = torch.clamp(images, 0.0, 1.0)
original_images = torch.clamp(original_images, 0.0, 1.0)
images *= 255.0
original_images *= 255.0
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
original_images = original_images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
images = np.concatenate([original_images, images], axis=2)
images = [Image.fromarray(image) for image in images]
# Log images
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: Original, Generated") for i, image in enumerate(images)
]
},
step=global_step,
)
torch.cuda.empty_cache()
return images
def log_grad_norm(model, accelerator, global_step):
for name, param in model.named_parameters():
if param.grad is not None:
grads = param.grad.detach().data
grad_norm = (grads.norm(p=2) / grads.numel()).item()
accelerator.log({"grad_norm/" + name: grad_norm}, step=global_step)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--log_grad_norm_steps",
type=int,
default=500,
help=("Print logs of gradient norms every X steps."),
)
parser.add_argument(
"--log_steps",
type=int,
default=50,
help=("Print logs every X steps."),
)
parser.add_argument(
"--validation_steps",
type=int,
default=100,
help=(
"Run validation every X steps. Validation consists of running reconstruction on images in"
" `args.validation_images` and logging the reconstructed images."
),
)
parser.add_argument(
"--vae_loss",
type=str,
default="l2",
help="The loss function for vae reconstruction loss.",
)
parser.add_argument(
"--timm_model_offset",
type=int,
default=0,
help="Offset of timm layers to indices.",
)
parser.add_argument(
"--timm_model_layers",
type=str,
default="head",
help="The layers to get output from in the timm model.",
)
parser.add_argument(
"--timm_model_backend",
type=str,
default="vgg19",
help="Timm model used to get the lpips loss",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--model_config_name_or_path",
type=str,
default=None,
help="The config of the Vq model to train, leave as None to use standard Vq model configuration.",
)
parser.add_argument(
"--discriminator_config_name_or_path",
type=str,
default=None,
help="The config of the discriminator model to train, leave as None to use standard Vq model configuration.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--validation_images",
type=str,
default=None,
nargs="+",
help=("A set of validation images evaluated every `--validation_steps` and logged to `--report_to`."),
)
parser.add_argument(
"--output_dir",
type=str,
default="vqgan-output",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--discr_learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--discr_lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
"--non_ema_revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
" remote repository specified with --pretrained_model_name_or_path."
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--prediction_type",
type=str,
default=None,
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
)
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="vqgan-training",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
# default to using the same revision for the non-ema model if not specified
if args.non_ema_revision is None:
args.non_ema_revision = args.revision
return args
def main():
#########################
# SETUP Accelerator #
#########################
args = parse_args()
# Enable TF32 on Ampere GPUs
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if accelerator.distributed_type == DistributedType.DEEPSPEED:
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
#####################################
# SETUP LOGGING, SEED and CONFIG #
#####################################
if accelerator.is_main_process:
tracker_config = dict(vars(args))
tracker_config.pop("validation_images")
accelerator.init_trackers(args.tracker_project_name, tracker_config)
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
#########################
# MODELS and OPTIMIZER #
#########################
logger.info("Loading models and optimizer")
if args.model_config_name_or_path is None and args.pretrained_model_name_or_path is None:
# Taken from config of movq at kandinsky-community/kandinsky-2-2-decoder but without the attention layers
model = VQModel(
act_fn="silu",
block_out_channels=[
128,
256,
512,
],
down_block_types=[
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
],
in_channels=3,
latent_channels=4,
layers_per_block=2,
norm_num_groups=32,
norm_type="spatial",
num_vq_embeddings=16384,
out_channels=3,
sample_size=32,
scaling_factor=0.18215,
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
vq_embed_dim=4,
)
elif args.pretrained_model_name_or_path is not None:
model = VQModel.from_pretrained(args.pretrained_model_name_or_path)
else:
config = VQModel.load_config(args.model_config_name_or_path)
model = VQModel.from_config(config)
if args.use_ema:
ema_model = EMAModel(model.parameters(), model_cls=VQModel, model_config=model.config)
if args.discriminator_config_name_or_path is None:
discriminator = Discriminator()
else:
config = Discriminator.load_config(args.discriminator_config_name_or_path)
discriminator = Discriminator.from_config(config)
idx = _map_layer_to_idx(args.timm_model_backend, args.timm_model_layers.split("|"), args.timm_model_offset)
timm_model = timm.create_model(
args.timm_model_backend,
pretrained=True,
features_only=True,
exportable=True,
out_indices=idx,
)
timm_model = timm_model.to(accelerator.device)
timm_model.requires_grad = False
timm_model.eval()
timm_transform = create_transform(**resolve_data_config(timm_model.pretrained_cfg, model=timm_model))
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
assert isinstance(
timm_centercrop_transform, transforms.CenterCrop
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
assert isinstance(
timm_model_normalization, transforms.Normalize
), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
if args.enable_xformers_memory_efficient_attention:
model.enable_xformers_memory_efficient_attention()
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
if args.use_ema:
ema_model.save_pretrained(os.path.join(output_dir, "vqmodel_ema"))
vqmodel = models[0]
discriminator = models[1]
vqmodel.save_pretrained(os.path.join(output_dir, "vqmodel"))
discriminator.save_pretrained(os.path.join(output_dir, "discriminator"))
weights.pop()
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "vqmodel_ema"), VQModel)
ema_model.load_state_dict(load_model.state_dict())
ema_model.to(accelerator.device)
del load_model
discriminator = models.pop()
load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator")
discriminator.register_to_config(**load_model.config)
discriminator.load_state_dict(load_model.state_dict())
del load_model
vqmodel = models.pop()
load_model = VQModel.from_pretrained(input_dir, subfolder="vqmodel")
vqmodel.register_to_config(**load_model.config)
vqmodel.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
learning_rate = args.learning_rate
if args.scale_lr:
learning_rate = (
learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
)
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
list(model.parameters()),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
discr_optimizer = optimizer_cls(
list(discriminator.parameters()),
lr=args.discr_learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
##################################
# DATLOADER and LR-SCHEDULER #
#################################
logger.info("Creating dataloaders and lr_scheduler")
args.train_batch_size * accelerator.num_processes
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
# DataLoaders creation:
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
assert args.image_column is not None
image_column = args.image_column
if image_column not in column_names:
raise ValueError(f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}")
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
]
)
validation_transform = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
return {"pixel_values": pixel_values}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_training_steps=args.max_train_steps,
num_warmup_steps=args.lr_warmup_steps,
)
discr_lr_scheduler = get_scheduler(
args.discr_lr_scheduler,
optimizer=discr_optimizer,
num_training_steps=args.max_train_steps,
num_warmup_steps=args.lr_warmup_steps,
)
# Prepare everything with accelerator
logger.info("Preparing model, optimizer and dataloaders")
# The dataloader are already aware of distributed training, so we don't need to prepare them.
model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler = accelerator.prepare(
model, discriminator, optimizer, discr_optimizer, lr_scheduler, discr_lr_scheduler
)
if args.use_ema:
ema_model.to(accelerator.device)
# Train!
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Potentially load in the weights and states from a previous save
resume_from_checkpoint = args.resume_from_checkpoint
if resume_from_checkpoint:
if resume_from_checkpoint != "latest":
path = resume_from_checkpoint
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
path = os.path.join(args.output_dir, path)
if path is None:
accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.")
resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(path)
accelerator.wait_for_everyone()
global_step = int(os.path.basename(path).split("-")[1])
first_epoch = global_step // num_update_steps_per_epoch
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
# As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to
# reuse the same training loop with other datasets/loaders.
avg_gen_loss, avg_discr_loss = None, None
for epoch in range(first_epoch, args.num_train_epochs):
model.train()
discriminator.train()
for i, batch in enumerate(train_dataloader):
pixel_values = batch["pixel_values"]
pixel_values = pixel_values.to(accelerator.device, non_blocking=True)
data_time_m.update(time.time() - end)
generator_step = ((i // args.gradient_accumulation_steps) % 2) == 0
# Train Step
# The behavior of accelerator.accumulate is to
# 1. Check if gradients are synced(reached gradient-accumulation_steps)
# 2. If so sync gradients by stopping the not syncing process
if generator_step:
optimizer.zero_grad(set_to_none=True)
else:
discr_optimizer.zero_grad(set_to_none=True)
# encode images to the latent space and get the commit loss from vq tokenization
# Return commit loss
fmap, commit_loss = model(pixel_values, return_dict=False)
if generator_step:
with accelerator.accumulate(model):
# reconstruction loss. Pixel level differences between input vs output
if args.vae_loss == "l2":
loss = F.mse_loss(pixel_values, fmap)
else:
loss = F.l1_loss(pixel_values, fmap)
# perceptual loss. The high level feature mean squared error loss
perceptual_loss = get_perceptual_loss(
pixel_values,
fmap,
timm_model,
timm_model_resolution=timm_model_resolution,
timm_model_normalization=timm_model_normalization,
)
# generator loss
gen_loss = -discriminator(fmap).mean()
last_dec_layer = accelerator.unwrap_model(model).decoder.conv_out.weight
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
adaptive_weight = norm_grad_wrt_perceptual_loss / norm_grad_wrt_gen_loss.clamp(min=1e-8)
adaptive_weight = adaptive_weight.clamp(max=1e4)
loss += commit_loss
loss += perceptual_loss
loss += adaptive_weight * gen_loss
# Gather the losses across all processes for logging (if we use distributed training).
avg_gen_loss = accelerator.gather(loss.repeat(args.train_batch_size)).float().mean()
accelerator.backward(loss)
if args.max_grad_norm is not None and accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
# log gradient norm before zeroing it
if (
accelerator.sync_gradients
and global_step % args.log_grad_norm_steps == 0
and accelerator.is_main_process
):
log_grad_norm(model, accelerator, global_step)
else:
# Return discriminator loss
with accelerator.accumulate(discriminator):
fmap.detach_()
pixel_values.requires_grad_()
real = discriminator(pixel_values)
fake = discriminator(fmap)
loss = (F.relu(1 + fake) + F.relu(1 - real)).mean()
gp = gradient_penalty(pixel_values, real)
loss += gp
avg_discr_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
accelerator.backward(loss)
if args.max_grad_norm is not None and accelerator.sync_gradients:
accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm)
discr_optimizer.step()
discr_lr_scheduler.step()
if (
accelerator.sync_gradients
and global_step % args.log_grad_norm_steps == 0
and accelerator.is_main_process
):
log_grad_norm(discriminator, accelerator, global_step)
batch_time_m.update(time.time() - end)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
global_step += 1
progress_bar.update(1)
if args.use_ema:
ema_model.step(model.parameters())
if accelerator.sync_gradients and not generator_step and accelerator.is_main_process:
# wait for both generator and discriminator to settle
# Log metrics
if global_step % args.log_steps == 0:
samples_per_second_per_gpu = (
args.gradient_accumulation_steps * args.train_batch_size / batch_time_m.val
)
logs = {
"step_discr_loss": avg_discr_loss.item(),
"lr": lr_scheduler.get_last_lr()[0],
"samples/sec/gpu": samples_per_second_per_gpu,
"data_time": data_time_m.val,
"batch_time": batch_time_m.val,
}
if avg_gen_loss is not None:
logs["step_gen_loss"] = avg_gen_loss.item()
accelerator.log(logs, step=global_step)
# resetting batch / data time meters per log window
batch_time_m.reset()
data_time_m.reset()
# Save model checkpoint
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
# Generate images
if global_step % args.validation_steps == 0:
if args.use_ema:
# Store the VQGAN parameters temporarily and load the EMA parameters to perform inference.
ema_model.store(model.parameters())
ema_model.copy_to(model.parameters())
log_validation(model, args, validation_transform, accelerator, global_step)
if args.use_ema:
# Switch back to the original VQGAN parameters.
ema_model.restore(model.parameters())
end = time.time()
# Stop training if max steps is reached
if global_step >= args.max_train_steps:
break
# End for
accelerator.wait_for_everyone()
# Save the final trained checkpoint
if accelerator.is_main_process:
model = accelerator.unwrap_model(model)
discriminator = accelerator.unwrap_model(discriminator)
if args.use_ema:
ema_model.copy_to(model.parameters())
model.save_pretrained(os.path.join(args.output_dir, "vqmodel"))
discriminator.save_pretrained(os.path.join(args.output_dir, "discriminator"))
accelerator.end_training()
if __name__ == "__main__":
main()
...@@ -41,6 +41,7 @@ class DecoderOutput(BaseOutput): ...@@ -41,6 +41,7 @@ class DecoderOutput(BaseOutput):
""" """
sample: torch.Tensor sample: torch.Tensor
commit_loss: Optional[torch.FloatTensor] = None
class Encoder(nn.Module): class Encoder(nn.Module):
......
...@@ -142,18 +142,20 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -142,18 +142,20 @@ class VQModel(ModelMixin, ConfigMixin):
) -> Union[DecoderOutput, torch.Tensor]: ) -> Union[DecoderOutput, torch.Tensor]:
# also go through quantization layer # also go through quantization layer
if not force_not_quantize: if not force_not_quantize:
quant, _, _ = self.quantize(h) quant, commit_loss, _ = self.quantize(h)
elif self.config.lookup_from_codebook: elif self.config.lookup_from_codebook:
quant = self.quantize.get_codebook_entry(h, shape) quant = self.quantize.get_codebook_entry(h, shape)
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
else: else:
quant = h quant = h
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
quant2 = self.post_quant_conv(quant) quant2 = self.post_quant_conv(quant)
dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
if not return_dict: if not return_dict:
return (dec,) return dec, commit_loss
return DecoderOutput(sample=dec) return DecoderOutput(sample=dec, commit_loss=commit_loss)
def forward( def forward(
self, sample: torch.Tensor, return_dict: bool = True self, sample: torch.Tensor, return_dict: bool = True
...@@ -173,9 +175,8 @@ class VQModel(ModelMixin, ConfigMixin): ...@@ -173,9 +175,8 @@ class VQModel(ModelMixin, ConfigMixin):
""" """
h = self.encode(sample).latents h = self.encode(sample).latents
dec = self.decode(h).sample dec = self.decode(h)
if not return_dict: if not return_dict:
return (dec,) return dec.sample, dec.commit_loss
return dec
return DecoderOutput(sample=dec)
...@@ -76,6 +76,7 @@ from .import_utils import ( ...@@ -76,6 +76,7 @@ from .import_utils import (
is_safetensors_available, is_safetensors_available,
is_scipy_available, is_scipy_available,
is_tensorboard_available, is_tensorboard_available,
is_timm_available,
is_torch_available, is_torch_available,
is_torch_npu_available, is_torch_npu_available,
is_torch_version, is_torch_version,
......
...@@ -295,6 +295,19 @@ try: ...@@ -295,6 +295,19 @@ try:
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_torchvision_available = False _torchvision_available = False
_timm_available = importlib.util.find_spec("timm") is not None
if _timm_available:
try:
_timm_version = importlib_metadata.version("timm")
logger.info(f"Timm version {_timm_version} available.")
except importlib_metadata.PackageNotFoundError:
_timm_available = False
def is_timm_available():
return _timm_available
_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None _bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
try: try:
_bitsandbytes_version = importlib_metadata.version("bitsandbytes") _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
......
...@@ -33,6 +33,7 @@ from .import_utils import ( ...@@ -33,6 +33,7 @@ from .import_utils import (
is_onnx_available, is_onnx_available,
is_opencv_available, is_opencv_available,
is_peft_available, is_peft_available,
is_timm_available,
is_torch_available, is_torch_available,
is_torch_version, is_torch_version,
is_torchsde_available, is_torchsde_available,
...@@ -340,6 +341,13 @@ def require_peft_backend(test_case): ...@@ -340,6 +341,13 @@ def require_peft_backend(test_case):
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
def require_timm(test_case):
"""
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
"""
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
def require_peft_version_greater(peft_version): def require_peft_version_greater(peft_version):
""" """
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
......
...@@ -98,3 +98,19 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -98,3 +98,19 @@ class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143]) expected_output_slice = torch.tensor([-0.0153, -0.4044, -0.1880, -0.5161, -0.2418, -0.4072, -0.1612, -0.0633, -0.0143])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
def test_loss_pretrained(self):
model = VQModel.from_pretrained("fusing/vqgan-dummy")
model.to(torch_device).eval()
torch.manual_seed(0)
backend_manual_seed(torch_device, 0)
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image).commit_loss.cpu()
# fmt: off
expected_output = torch.tensor([0.1936])
# fmt: on
self.assertTrue(torch.allclose(output, expected_output, atol=1e-3))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment