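"""Train an unconditional image-generation model (DDPM) with diffusers.

A UNet2DModel is trained to predict the noise that a DDPMScheduler adds to
clean images; sample images are generated during training with DDPMPipeline.
A typical launch (the flags shown are this script's defaults) might look like:

    accelerate launch train_unconditional.py \
        --dataset="huggan/flowers-102-categories" \
        --resolution=64 \
        --output_dir="ddpm-flowers-64"
"""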
import argparse
import os

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = get_logger(__name__)


def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

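    # UNet that predicts the noise residual added to an image at a given diffusion timestep.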
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
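    # The DDPM scheduler defines the forward noising process and the 1000-step denoising schedule.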
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
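    # AdamW over the raw model parameters; accelerator.prepare() wraps the optimizer below.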
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
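            # scale pixel values from [0, 1] to [-1, 1], the range the model is trained on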
            Normalize([0.5], [0.5]),
        ]
    )
    dataset = load_dataset(args.dataset, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

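    # Schedule the LR over optimizer updates: total batches divided by gradient accumulation steps.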
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

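    # Keep an exponential moving average of the weights; the EMA copy is used for sampling below.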
    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

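    # Optionally create the Hub repo up front so intermediate pipelines can be pushed during training.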
    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

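    # Log to TensorBoard from the main process only, under a run named after this script.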
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

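    # Training loop: noise clean images at random timesteps and regress the added noise with MSE.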
    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps)["sample"]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
                    ema_model.step(model)
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,
                )

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")
                accelerator.trackers[0].writer.add_images(
                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                )

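            # NOTE: reuses the pipeline built in the sampling branch above; this assumes
            # save_model_epochs is a multiple of save_images_epochs (true for the defaults).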
            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model
                if args.push_to_hub:
                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
                else:
                    pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10)
    parser.add_argument("--save_model_epochs", type=int, default=10)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--adam_epsilon", type=float, default=1e-3)
    parser.add_argument("--use_ema", action="store_true", default=True)
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)