"examples/images/dreambooth/requirements_colossalai.txt" did not exist on "1cf6d92d7c93a26e29cadeb71bb34ee96b149a28"
train_dreambooth_colossalai.py 27.9 KB
Newer Older
1
2
3
4
5
6
import argparse
import hashlib
import math
import os
from pathlib import Path
from typing import Optional
Maruyama_Aya's avatar
Maruyama_Aya committed
7
import shutil
8
9
10
11

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
12
13
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
14
from huggingface_hub import HfFolder, Repository, create_repo, whoami
15
from PIL import Image
16
from torch.utils.data import Dataset
17
18
19
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
20
21
22
23
24

import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
Maruyama_Aya's avatar
Maruyama_Aya committed
25
from colossalai.nn.optimizer import HybridAdam
26
from colossalai.utils import get_current_device
Maruyama_Aya's avatar
Maruyama_Aya committed
27
from colossalai.zero import ColoInitContext
28
from colossalai.zero.gemini import get_static_torch_model
Maruyama_Aya's avatar
Maruyama_Aya committed
29
30
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64

# Module-level logging setup. `disable_existing_loggers()` (ColossalAI) runs first,
# presumably to quiet loggers already created by imported libraries -- see the
# ColossalAI logging docs for its exact scope. Then the distributed logger used
# throughout this file is created; calls below that pass `ranks=[0]` presumably
# restrict output to rank 0.
disable_existing_loggers()
logger = get_dist_logger()


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: Optional[str] = None):
    """Return the text-encoder class declared by the checkpoint's ``text_encoder`` config.

    Args:
        pretrained_model_name_or_path: Local path or hub id of the diffusion checkpoint.
        revision: Model revision used to read the config. When omitted, falls back to
            ``args.revision`` from the module-level ``args`` bound by the ``__main__``
            entry point (or ``None`` if no such ``args`` exists).

    Returns:
        ``transformers.CLIPTextModel`` or diffusers' ``RobertaSeriesModelWithTransformation``.

    Raises:
        ValueError: If the config declares any other (or no) architecture.
    """
    if revision is None:
        # Backward compatibility: callers that do not pass `revision` keep using the
        # module-level `args.revision`. `globals().get` avoids a NameError when this module
        # is imported and used without running the `__main__` entry point.
        revision = getattr(globals().get("args"), "revision", None)

    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    # `architectures` may be None on hand-written configs; treat that as unsupported
    # instead of failing with a TypeError on indexing.
    architectures = text_encoder_config.architectures or [None]
    model_class = architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    A plain string option treats every non-empty value (including ``"False"``) as truthy,
    so boolean-valued options that accept an explicit value need a real parser.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args(input_args=None):
    """Parse the command-line options of the DreamBooth training script.

    Args:
        input_args: Optional list of argument strings; ``None`` parses ``sys.argv``.

    Returns:
        argparse.Namespace with all options. ``local_rank`` is overridden by the
        ``LOCAL_RANK`` environment variable when that is set (e.g. by ``torchrun``).

    Raises:
        ValueError: If ``--with_prior_preservation`` is given without ``--class_data_dir``
            or ``--class_prompt``.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--externel_unet_path",
        type=str,
        default=None,
        required=False,
        help="Path to the externel unet model.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default="a photo of sks dog",
        required=False,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
              " class_data_dir, additional images will be sampled with class_prompt."),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
              " resolution"),
    )
    parser.add_argument(
        "--placement",
        type=str,
        default="cpu",
        help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
              " cropped. The images will be resized to the resolution first before cropping."),
    )
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
    )
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
              ' "constant", "constant_with_warmup"]'),
    )
    parser.add_argument("--lr_warmup_steps",
                        type=int,
                        default=500,
                        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--use_8bit_adam",
                        action="store_true",
                        help="Whether or not to use 8-bit Adam from bitsandbytes.")

    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    # Accept both the bare flag (`--test_run`) and an explicit value (`--test_run False`);
    # without a boolean type, any supplied string -- even "False" -- would be truthy.
    parser.add_argument("--test_run",
                        type=_str2bool,
                        nargs="?",
                        const=True,
                        default=False,
                        help="Whether to use a smaller dataset for test run.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument('-p',
                        '--plugin',
                        type=str,
                        default='torch_ddp',
                        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
                        help="plugin to use")
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Launchers such as torchrun export LOCAL_RANK; it takes precedence over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        if args.class_data_dir is not None:
            logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            logger.warning("You need not use --class_prompt without --with_prior_preservation.")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.

    Each item is a dict with ``instance_images`` (the transformed image tensor) and
    ``instance_prompt_ids`` (un-padded token ids). When ``class_data_root`` is given, it also holds
    ``class_images`` and ``class_prompt_ids``. The dataset length is the larger of the instance and
    class image counts; indices wrap around the shorter list.

    Args:
        instance_data_root: Directory of instance images; must exist and contain at least one file.
        instance_prompt: Prompt paired with every instance image.
        tokenizer: Tokenizer used to encode prompts (must expose ``model_max_length``).
        class_data_root: Optional directory of class images (created if missing) for prior preservation.
        class_prompt: Prompt paired with every class image.
        size: Target size for resizing and cropping.
        center_crop: Center-crop when True, random-crop otherwise.
        test: When True, keep only the first 10 instance images (for quick test runs).

    Raises:
        ValueError: If the instance directory does not exist, or an image directory has no files.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
        test=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = self._list_files(self.instance_data_root)
        if test:
            self.instance_images_path = self.instance_images_path[:10]
        if not self.instance_images_path:
            # __getitem__ wraps indices with `% num_instance_images`; an empty folder would
            # otherwise surface later as a ZeroDivisionError.
            raise ValueError(f"No instance images found in {self.instance_data_root}.")
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = self._list_files(self.class_data_root)
            if not self.class_images_path:
                raise ValueError(f"No class images found in {self.class_data_root}.")
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    @staticmethod
    def _list_files(root):
        """Return the regular files directly under ``root``, sorted by path.

        Sorting makes the ``test`` subset deterministic across file systems; directories are
        skipped because ``Image.open`` cannot read them.
        """
        return sorted(path for path in root.iterdir() if path.is_file())

    def _load_image(self, path):
        """Open ``path``, convert it to RGB if needed, and apply ``image_transforms``."""
        # The context manager closes the underlying file handle; the transforms run inside it
        # so the pixel data is fully read before the file is closed.
        with Image.open(path) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            return self.image_transforms(image)

    def _tokenize(self, prompt):
        """Tokenize ``prompt`` without padding, truncated to the tokenizer's max length."""
        return self.tokenizer(
            prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {
            "instance_images": self._load_image(self.instance_images_path[index % self.num_instance_images]),
            "instance_prompt_ids": self._tokenize(self.instance_prompt),
        }

        if self.class_data_root is not None:
            example["class_images"] = self._load_image(self.class_images_path[index % self.num_class_images])
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example


class PromptDataset(Dataset):
    """Serve one fixed prompt ``num_samples`` times, each tagged with its position.

    Wrapped in a DataLoader to batch the prompts used for class-image generation.
    Items are ``{"prompt": <prompt>, "index": <index>}``.
    """

    def __init__(self, prompt, num_samples):
        self.prompt, self.num_samples = prompt, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    """Build the ``<namespace>/<model_id>`` Hugging Face Hub repository name.

    A missing ``token`` is first filled from ``HfFolder.get_token()``. The namespace is
    ``organization`` when given; otherwise it is the ``name`` that ``whoami`` reports
    for the token.
    """
    resolved_token = HfFolder.get_token() if token is None else token
    namespace = organization if organization is not None else whoami(resolved_token)["name"]
    return f"{namespace}/{model_id}"


def main(args):
    """Fine-tune a Stable Diffusion UNet with DreamBooth using the ColossalAI Booster API.

    Pipeline:
      1. Launch ColossalAI from the torchrun environment (seeded when ``args.seed`` is set).
      2. With ``--with_prior_preservation``, top up ``args.class_data_dir`` to
         ``args.num_class_images`` images sampled from ``args.class_prompt``. The missing
         images are split across data-parallel ranks.
      3. On rank 0, create ``args.output_dir``, or clone the hub repo when ``--push_to_hub`` is set.
      4. Load the tokenizer, text encoder, VAE and UNet; only the UNet is trained.
      5. Boost the UNet, optimizer and LR scheduler with the selected plugin, train, and save
         the UNet weights every ``args.save_steps`` steps and at the end.

    Args:
        args: Namespace from ``parse_args``. ``learning_rate`` (with ``--scale_lr``),
            ``max_train_steps`` and ``num_train_epochs`` are updated in place.

    Raises:
        ValueError: For an unsupported ``args.plugin`` or an empty training dataloader.
    """
    if args.seed is None:
        colossalai.launch_from_torch(config={})
    else:
        colossalai.launch_from_torch(config={}, seed=args.seed)

    local_rank = gpc.get_local_rank(ParallelMode.DATA)
    world_size = gpc.get_world_size(ParallelMode.DATA)

    def _barrier():
        # Synchronize all ranks when a process group exists; no-op otherwise.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.barrier()

    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        # exist_ok: several ranks may try to create the directory at the same time.
        class_images_dir.mkdir(parents=True, exist_ok=True)
        cur_class_images = len(list(class_images_dir.iterdir()))
        # Every rank must finish counting before any rank writes new images, so that all
        # ranks agree on how many images are missing and on the file-name offsets.
        _barrier()

        if cur_class_images < args.num_class_images:
            device = get_current_device()
            # get_current_device() may return a torch.device such as `cuda:0`, which never equals
            # the plain string "cuda"; compare on the device type instead.
            torch_dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            # Shard the prompts across data-parallel ranks so each missing image is generated
            # once, rather than once per rank. Subset keeps the original "index" of each item,
            # so file names stay unique across ranks.
            sample_dataset = torch.utils.data.Subset(sample_dataset, range(local_rank, num_new_images, world_size))
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            pipeline.to(device)

            for example in tqdm(
                    sample_dataloader,
                    desc="Generating class images",
                    disable=local_rank != 0,
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha256(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # All shards must be on disk before any rank lists this directory for training.
        _barrier()

    # Handle the repository creation
    repo = None
    if local_rank == 0:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            # Keep intermediate step/epoch outputs out of the pushed repo. Append only the
            # missing patterns so an existing (e.g. cloned) .gitignore is preserved.
            gitignore_path = os.path.join(args.output_dir, ".gitignore")
            content = ""
            if os.path.exists(gitignore_path):
                with open(gitignore_path) as gitignore:
                    content = gitignore.read()
            present = {line.strip() for line in content.splitlines()}
            missing = [pattern for pattern in ("step_*", "epoch_*") if pattern not in present]
            if missing:
                with open(gitignore_path, "a") as gitignore:
                    if content and not content.endswith("\n"):
                        gitignore.write("\n")
                    gitignore.writelines(f"{pattern}\n" for pattern in missing)
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name,
            revision=args.revision,
            use_fast=False,
        )
    elif args.pretrained_model_name_or_path:
        logger.info("Loading tokenizer from pretrained model", ranks=[0])
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)

    # Load models and create wrapper for stable diffusion
    logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=args.revision,
    )

    logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
    )

    if args.externel_unet_path is None:
        logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
        unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
                                                    subfolder="unet",
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)
    else:
        logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
        unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
                                                    revision=args.revision,
                                                    low_cpu_mem_usage=False)

    # Only the UNet is trained; the VAE and text encoder are used for inference only.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = args.learning_rate * args.train_batch_size * world_size

    # Use Booster API to use Gemini/Zero with ColossalAI
    booster_kwargs = {}
    if args.plugin == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if args.plugin.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif args.plugin == 'gemini':
        plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
    elif args.plugin == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
    else:
        raise ValueError(f"Unsupported plugin: {args.plugin}")

    booster = Booster(plugin=plugin, **booster_kwargs)

    # config optimizer for colossalai zero
    optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)

    # load noise_scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

    # prepare dataset
    logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
        test=args.test_run,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }

    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.train_batch_size,
                                                   shuffle=True,
                                                   collate_fn=collate_fn,
                                                   num_workers=1)

    # Scheduler and math around the number of training steps. One optimizer step per batch
    # (no gradient accumulation), and the dataloader is not re-wrapped by the booster, so
    # its length is the number of update steps per epoch.
    num_update_steps_per_epoch = len(train_dataloader)
    if num_update_steps_per_epoch == 0:
        raise ValueError(f"The training dataloader is empty; check --instance_data_dir ({args.instance_data_dir}).")
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Recompute epochs so that training covers max_train_steps.
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move text_encode and vae to gpu.
    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    vae.to(get_current_device(), dtype=weight_dtype)
    text_encoder.to(get_current_device(), dtype=weight_dtype)

    unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)

    # Train!
    total_batch_size = args.train_batch_size * world_size

    logger.info("***** Running training *****", ranks=[0])
    logger.info(f"  Num examples = {len(train_dataset)}", ranks=[0])
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}", ranks=[0])
    logger.info(f"  Num Epochs = {args.num_train_epochs}", ranks=[0])
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
    logger.info(f"  Total optimization steps = {args.max_train_steps}", ranks=[0])

    # Only show the progress bar on rank 0.
    progress_bar = tqdm(range(args.max_train_steps), disable=local_rank != 0)
    progress_bar.set_description("Steps")
    global_step = 0

    torch.cuda.synchronize()
    for epoch in range(args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            torch.cuda.reset_peak_memory_stats()
            # Move batch to gpu
            for key, value in batch.items():
                batch[key] = value.to(get_current_device(), non_blocking=True)

            # Convert images to latent space
            optimizer.zero_grad()

            latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Predict the noise residual
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            if args.with_prior_preservation:
                # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)

                # Compute instance loss
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()

                # Compute prior loss
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                # Add the prior loss to the instance loss.
                loss = loss + args.prior_loss_weight * prior_loss
            else:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])

            # One optimizer step was taken for this batch.
            progress_bar.update(1)
            global_step += 1
            logs = {
                "loss": loss.detach().item(),
                "lr": optimizer.param_groups[0]["lr"],
            }
            progress_bar.set_postfix(**logs)

            if global_step % args.save_steps == 0:
                torch.cuda.synchronize()
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                # The booster writes straight to the given file path, so the checkpoint
                # directory must exist first.
                os.makedirs(save_path, exist_ok=True)
                booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
                if local_rank == 0:
                    if not os.path.exists(os.path.join(save_path, "config.json")):
                        shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
                    logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
            if global_step >= args.max_train_steps:
                break
        # Stop the epoch loop as well once the step budget is exhausted.
        if global_step >= args.max_train_steps:
            break
    torch.cuda.synchronize()

    booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
    logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
    if local_rank == 0:
        if not os.path.exists(os.path.join(args.output_dir, "config.json")):
            shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

if __name__ == "__main__":
    # Script entry point. `args` is deliberately bound at module level:
    # import_model_class_from_model_name_or_path (called from main without a
    # revision argument) looks up `args.revision` in module globals.
    args = parse_args()
    main(args)