import argparse
import math
import os

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm


logger = get_logger(__name__)


def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
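    # Accelerate handles device placement, gradient accumulation, mixed precision, and TensorBoard logging.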
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with="tensorboard",
        logging_dir=logging_dir,
    )

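    # UNet denoiser: given a noisy image and a timestep, it predicts the noise residual.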
    model = UNet2DModel(
        sample_size=args.resolution,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )
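    # The DDPM scheduler defines the noise schedule used for training and sampling.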
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
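    # Plain AdamW over all UNet parameters.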
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

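    # Resize, crop, and flip augmentations; Normalize([0.5], [0.5]) maps pixels from [0, 1] to [-1, 1].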
    augmentations = Compose(
        [
            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
            CenterCrop(args.resolution),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize([0.5], [0.5]),
        ]
    )

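    # Load training images either from a Hub dataset or from a local folder via the "imagefolder" loader.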
    if args.dataset_name is not None:
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
            split="train",
        )
    else:
        dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")

    def transforms(examples):
        images = [augmentations(image.convert("RGB")) for image in examples["image"]]
        return {"input": images}

    dataset.set_transform(transforms)
    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)

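    # Warmup LR schedule; the step count is in optimizer updates, hence the division by gradient_accumulation_steps.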
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
    )

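    # Let Accelerate wrap everything for the current (possibly distributed) device setup.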
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

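    # Exponential moving average of the weights; the EMA copy is used for sampling below when --use_ema is set.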
    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)

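    # Set up the Hub git repo that checkpoints are pushed to.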
    if args.push_to_hub:
        repo = init_git_repo(args, at_init=True)

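    # Start the TensorBoard tracker on the main process, named after this script.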
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    global_step = 0
    for epoch in range(args.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["input"]
            # Sample noise that we'll add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bsz = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
            ).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

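            # Accumulate gradients over gradient_accumulation_steps micro-batches per optimizer update.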
            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps).sample
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

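                # Gradients are only synchronized on accumulation boundaries, so clip only then.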
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                if args.use_ema:
                    ema_model.step(model)
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

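            # Per-step logging to the progress bar and TensorBoard.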
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            if args.use_ema:
                logs["ema_decay"] = ema_model.decay
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        progress_bar.close()

        accelerator.wait_for_everyone()

        # Generate sample images for visual inspection
        if accelerator.is_main_process:
            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
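                # Package the (EMA) UNet and scheduler into a pipeline for sampling.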
                pipeline = DDPMPipeline(
                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
                    scheduler=noise_scheduler,
                )

                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images

                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")
                accelerator.trackers[0].writer.add_images(
                    "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                )

            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                # save the model
                if args.push_to_hub:
                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
                else:
                    pipeline.save_pretrained(args.output_dir)
        accelerator.wait_for_everyone()

    accelerator.end_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset_name", type=str, default=None)
    parser.add_argument("--dataset_config_name", type=str, default=None)
    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=64)
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--save_images_epochs", type=int, default=10)
    parser.add_argument("--save_model_epochs", type=int, default=10)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--adam_beta1", type=float, default=0.95)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
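    # NOTE: action="store_true" with default=True means EMA is effectively always enabled.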
    parser.add_argument("--use_ema", action="store_true", default=True)
    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
    parser.add_argument("--ema_power", type=float, default=3 / 4)
    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_token", type=str, default=None)
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    parser.add_argument("--logging_dir", type=str, default="logs")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )

    args = parser.parse_args()
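    # Let a distributed launcher's LOCAL_RANK environment variable override the CLI flag.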
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")

    main(args)