import argparse
import os

import torch
import torch.nn.functional as F

from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = get_logger(__name__)


def main(args):
    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
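    # Accelerate handles device placement, DDP wrapping, mixed precision, and TensorBoard logging.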
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
        kwargs_handlers=[ddp_unused_params],
    )

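    # DDPM-style UNet; self-attention is applied at the 16x16 feature-map resolution.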
    model = UNetModel(
        attn_resolutions=(16,),
        ch=128,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        num_res_blocks=2,
        resamp_with_conv=True,
        resolution=args.resolution,
    )
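    # The scheduler defines the 1000-step forward (noising) process used to build training targets.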
    noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

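    # Resize, center-crop, random-flip, and normalize images to [-1, 1], the model's expected range.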
    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )
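    # Load the training split from the Hub and apply the augmentations lazily via set_transform.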
    dataset = load_dataset(args.dataset, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

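    # LR schedule with warmup; the total step count accounts for gradient accumulation.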
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

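    # Let Accelerate wrap the model, optimizer, dataloader, and scheduler for distributed execution.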
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

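    # Track an exponential moving average of the weights; the EMA copy is what gets sampled below.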
    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

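    # Optionally create (or clone) the Hub repository that checkpoints will be pushed to.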
    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

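    # Only the main process initializes the TensorBoard tracker.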
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    # Train!
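    # Effective batch size = per-device batch size x accumulation steps x number of processes.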
    world_size = accelerator.num_processes
    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_steps}")

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()

            # add noise onto the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)

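            # On accumulation steps, skip the DDP gradient sync; gradients are synced and the
            # optimizer stepped only once every `gradient_accumulation_steps` batches.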
            if step % args.gradient_accumulation_steps != 0:
                with accelerator.no_sync(model):
                    output = model(noisy_images, timesteps)
                    # predict the noise residual
                    loss = F.mse_loss(output, noise_samples)
                    loss = loss / args.gradient_accumulation_steps
                    accelerator.backward(loss)
            else:
                output = model(noisy_images, timesteps)
                # predict the noise residual
                loss = F.mse_loss(output, noise_samples)
                loss = loss / args.gradient_accumulation_steps
                accelerator.backward(loss)
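                # Sync step: clip gradients, update the weights, then advance the LR schedule and EMA.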
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                ema_model.step(model)
                optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_postfix(
                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
            )
            accelerator.log(
                {
                    "train_loss": loss.detach().item(),
                    "epoch": epoch,
                    "ema_decay": ema_model.decay,
                    "step": global_step,
                },
                step=global_step,
            )
            global_step += 1
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate a sample image for visual inspection
        if accelerator.is_main_process:
            with torch.no_grad():
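                # Build an inference pipeline around the EMA weights for this epoch's samples.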
                pipeline = DDIMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model),
                    noise_scheduler=noise_scheduler,
                )

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)

            # denormalize the images and save to tensorboard
            images_processed = (images.cpu() + 1.0) * 127.5
            images_processed = images_processed.clamp(0, 255).type(torch.uint8).numpy()

            accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)

            # save the model
            if args.push_to_hub:
                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
            else:
                pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
    parser.add_argument("--output_dir", type=str, default="ddpm-model")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--adam_epsilon", type=float, default=1e-3)
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
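    # Prefer the LOCAL_RANK set by the distributed launcher over the CLI flag.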
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    main(args)