# train_unconditional.py
1
import argparse
2
import inspect
3
import logging
4
import math
anton-l's avatar
anton-l committed
5
import os
6
7
from pathlib import Path
from typing import Optional
anton-l's avatar
anton-l committed
8
9
10
11

import torch
import torch.nn.functional as F

12
13
import datasets
import diffusers
14
from accelerate import Accelerator
15
from accelerate.logging import get_logger
anton-l's avatar
anton-l committed
16
from datasets import load_dataset
17
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
18
from diffusers.optimization import get_scheduler
anton-l's avatar
anton-l committed
19
from diffusers.training_utils import EMAModel
20
from diffusers.utils import check_min_version
21
from huggingface_hub import HfFolder, Repository, create_repo, whoami
anton-l's avatar
anton-l committed
22
from torchvision.transforms import (
Patrick von Platen's avatar
Patrick von Platen committed
23
    CenterCrop,
anton-l's avatar
anton-l committed
24
25
    Compose,
    InterpolationMode,
anton-l's avatar
anton-l committed
26
    Normalize,
anton-l's avatar
anton-l committed
27
28
29
30
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
anton-l's avatar
anton-l committed
31
from tqdm.auto import tqdm
anton-l's avatar
anton-l committed
32
33


34
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.13.0.dev0")

# Module-level logger (accelerate's multi-process aware logger).
logger = get_logger(__name__, log_level="INFO")


def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D numpy array for a batch of indices.

    :param arr: the 1-D numpy array.
    :param timesteps: a tensor of indices into the array to extract.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    if not isinstance(arr, torch.Tensor):
        arr = torch.from_numpy(arr)
    res = arr[timesteps].float().to(timesteps.device)
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)


58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def parse_args():
    """Parse the command-line options of this training script.

    If the ``LOCAL_RANK`` environment variable is set (and is not -1) it overrides
    ``--local_rank``.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        ValueError: if neither ``--dataset_name`` nor ``--train_data_dir`` is given.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that HF Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. Only the images are used; no"
            " captions or labels are required. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="ddpm-model-64",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    # NOTE(review): not read by main() in this script.
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=64,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
    parser.add_argument(
        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of batches to accumulate gradients over before performing an optimizer update.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="cosine",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument(
        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
    )
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use Exponential Moving Average for the final model weights.",
    )
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
    )
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
        help=(
            "Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
            " for experiment tracking and logging of model metrics and model checkpoints"
        ),
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory, resolved relative to"
            " `--output_dir` (defaults to *output_dir/logs*)."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        type=str,
        default="epsilon",
        choices=["epsilon", "sample"],
        help=(
            "Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image"
            " 'sample' (x0)."
        ),
    )
    parser.add_argument("--ddpm_num_steps", type=int, default=1000)
    parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )

    args = parser.parse_args()

    # Distributed launchers export LOCAL_RANK; prefer it over the CLI value when present.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    return args


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the ``"<namespace>/<model_id>"`` Hub repository id.

    The namespace is ``organization`` when given; otherwise it is the user name that
    ``whoami`` reports for ``token``. A missing ``token`` is first looked up with
    ``HfFolder.get_token()``.
    """
    resolved_token = HfFolder.get_token() if token is None else token
    namespace = whoami(resolved_token)["name"] if organization is None else organization
    return f"{namespace}/{model_id}"


def _ensure_gitignore_entries(directory, patterns):
    """Make sure every pattern in ``patterns`` is listed in ``directory/.gitignore``.

    The file is read first and only the missing patterns are appended. Entries that already
    exist (e.g. in a ``.gitignore`` that came with the cloned Hub repository) are preserved,
    and repeated runs do not duplicate lines.

    :param directory: folder holding the ``.gitignore`` file (the file is created if absent).
    :param patterns: gitignore patterns, one per line.
    """
    gitignore_path = os.path.join(directory, ".gitignore")
    content = ""
    if os.path.isfile(gitignore_path):
        with open(gitignore_path, encoding="utf-8") as gitignore:
            content = gitignore.read()
    present = {line.strip() for line in content.splitlines()}
    missing = [pattern for pattern in patterns if pattern not in present]
    if not missing:
        return
    # Don't glue the first new pattern onto an unterminated last line.
    separator = "\n" if content and not content.endswith("\n") else ""
    with open(gitignore_path, "a", encoding="utf-8") as gitignore:
        gitignore.write(separator + "".join(f"{pattern}\n" for pattern in missing))


def main(args):
    """Train an unconditional ``UNet2DModel`` with a DDPM noise schedule using Accelerate.

    Sets up the accelerator and logging, optionally creates/clones the Hub repository, builds
    the model (plus optional EMA), scheduler, optimizer, dataset and LR schedule, optionally
    resumes from a saved state, then trains. At the configured epochs it samples images (logged
    to TensorBoard when that logger is selected) and saves the pipeline to ``args.output_dir``.

    :param args: the namespace returned by ``parse_args()``.
    """
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.logger,
        logging_dir=logging_dir,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            # Honour --hub_private_repo (previously parsed but never used).
            create_repo(repo_name, exist_ok=True, token=args.hub_token, private=args.hub_private_repo)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            _ensure_gitignore_entries(args.output_dir, ["step_*", "epoch_*"])
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Initialize the model
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    # Create EMA for the model.
    if args.use_ema:
        ema_model = EMAModel(
            model.parameters(),
            decay=args.ema_max_decay,
            use_ema_warmup=True,
            inv_gamma=args.ema_inv_gamma,
            power=args.ema_power,
        )

    # Initialize the scheduler (older DDPMScheduler versions have no `prediction_type` argument).
    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
    if accepts_prediction_type:
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=args.ddpm_num_steps,
            beta_schedule=args.ddpm_beta_schedule,
            prediction_type=args.prediction_type,
        )
    else:
        noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)

    # Initialize the optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    # Preprocessing the datasets and DataLoaders creation.
    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    logger.info(f"Dataset size: {len(dataset)}")

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
    )

    # Initialize the learning rate scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs),
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        accelerator.register_for_checkpointing(ema_model)
        ema_model.to(accelerator.device)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_train_steps = args.num_epochs * num_update_steps_per_epoch

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # NOTE: only the last path component is used; the state is loaded from inside output_dir.
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint (there is none if output_dir does not exist yet).
            dirs = os.listdir(args.output_dir) if os.path.isdir(args.output_dir) else []
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            # Convert optimizer steps back into dataloader steps to know how many batches to skip.
            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Train!
    for epoch in range(first_epoch, args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            # (resume_step is only defined when a checkpoint was actually loaded).
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            clean_images = batch["input"]
            # Sample noise that we'll add to the images, directly on the target device.
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps).sample

                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
                elif args.prediction_type == "sample":
                    # Put alphas_cumprod on the timesteps' device before gathering: indexing a CPU
                    # tensor with CUDA indices is rejected by PyTorch.
                    alphas_cumprod = torch.as_tensor(noise_scheduler.alphas_cumprod, device=timesteps.device)
                    alpha_t = _extract_into_tensor(alphas_cumprod, timesteps, (bsz, 1, 1, 1))
                    snr_weights = alpha_t / (1 - alpha_t)
                    # use SNR weighting from distillation paper
                    loss = snr_weights * F.mse_loss(model_output, clean_images, reduction="none")
                    loss = loss.mean()
                else:
                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_model.step(model.parameters())
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection and/or save the pipeline
        if accelerator.is_main_process:
            is_last_epoch = epoch == args.num_epochs - 1
            save_images = epoch % args.save_images_epochs == 0 or is_last_epoch
            save_model = epoch % args.save_model_epochs == 0 or is_last_epoch

            # Build the pipeline whenever either action needs it (it used to exist only on
            # image epochs, so a model-only epoch could hit a NameError).
            if save_images or save_model:
                unet = accelerator.unwrap_model(model)

                if args.use_ema:
                    # copy_to() overwrites the live weights. Snapshot them so training continues
                    # from the raw (non-averaged) weights and stays in sync with the other ranks.
                    training_weights = [param.detach().clone() for param in unet.parameters()]
                    ema_model.copy_to(unet.parameters())

                pipeline = DDPMPipeline(
                    unet=unet,
                    scheduler=noise_scheduler,
                )

                if save_images:
                    generator = torch.Generator(device=pipeline.device).manual_seed(0)
                    # run pipeline in inference (sample random noise and denoise)
                    images = pipeline(
                        generator=generator,
                        batch_size=args.eval_batch_size,
                        output_type="numpy",
                    ).images

                    # denormalize the images and save to tensorboard
                    images_processed = (images * 255).round().astype("uint8")

                    if args.logger == "tensorboard":
                        accelerator.get_tracker("tensorboard").add_images(
                            "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                        )

                if save_model:
                    # save the model (with EMA weights when enabled)
                    pipeline.save_pretrained(args.output_dir)
                    if args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)

                if args.use_ema:
                    # Put the training weights back now that sampling/saving is done.
                    with torch.no_grad():
                        for param, saved in zip(unet.parameters(), training_weights):
                            param.copy_(saved)
                    del training_weights

    accelerator.end_training()


if __name__ == "__main__":
    # Script entry point: parse CLI options, then run training.
    args = parse_args()
    main(args)