animatediff.py 21.2 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
import csv
import datetime
import inspect
import logging
import os
import random
import re
from copy import deepcopy
from types import MethodType
from typing import Dict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from decord import VideoReader
from diffusers import AutoencoderKL, DDIMScheduler, MotionAdapter, UNet2DConditionModel, UNetMotionModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines import AnimateDiffPipeline
from diffusers.utils import export_to_gif
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from modelscope import snapshot_download
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import RandomSampler
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from swift import LoRAConfig, Swift, get_logger, push_to_hub
from swift.aigc.utils import AnimateDiffArguments
from swift.utils import get_dist_setting, get_main, is_dist

logger = get_logger()


class AnimateDiffDataset(Dataset):

    VIDEO_ID = 'videoid'
    NAME = 'name'
    CONTENT_URL = 'contentUrl'

    def __init__(
        self,
        csv_path,
        video_folder,
        sample_size=256,
        sample_stride=4,
        sample_n_frames=16,
        dataset_sample_size=10000,
    ):
        print(f'loading annotations from {csv_path} ...')
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        dataset = []
        for d in tqdm(self.dataset):
            content_url = d[self.CONTENT_URL]
            file_name = content_url.split('/')[-1]
            if os.path.isfile(os.path.join(video_folder, file_name)):
                dataset.append(d)
            if dataset_sample_size is not None and len(dataset) > dataset_sample_size:
                break

        self.dataset = dataset
        self.length = len(self.dataset)
        print(f'data scale: {self.length}')

        self.video_folder = video_folder
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict: Dict[str, str] = self.dataset[idx]
        name = video_dict[self.NAME]

        content_url = video_dict[self.CONTENT_URL]
        file_name = content_url.split('/')[-1]
        video_dir = os.path.join(self.video_folder, file_name)
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader
        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                logger.error(f'Error loading dataset batch: {e}')
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, duration=4):
    import imageio
    videos = rearrange(videos, 'b c t h w -> t b c h w')
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, duration=duration)


def animatediff_sft(args: AnimateDiffArguments) -> None:
    # Initialize distributed training
    if is_dist():
        _, local_rank, num_processes, _ = get_dist_setting()
        global_rank = dist.get_rank()
    else:
        local_rank = 0
        global_rank = 0
        num_processes = 1
    is_main_process = global_rank == 0

    global_seed = args.seed + global_rank
    torch.manual_seed(global_seed)

    # Logging folder
    folder_name = datetime.datetime.now().strftime('ad-%Y-%m-%dT%H-%M-%S')
    output_dir = os.path.join(args.output_dir, folder_name)

    *_, config = inspect.getargvalues(inspect.currentframe())

    if is_main_process and args.use_wandb:
        import wandb
        wandb.init(project='animatediff', name=folder_name, config=config)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f'{output_dir}/samples', exist_ok=True)
        os.makedirs(f'{output_dir}/sanity_check', exist_ok=True)
        os.makedirs(f'{output_dir}/checkpoints', exist_ok=True)

    with open(args.validation_prompts_path, 'r') as f:
        validation_data = f.readlines()

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=args.num_train_timesteps,
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        beta_schedule=args.beta_schedule,
        steps_offset=args.steps_offset,
        clip_sample=args.clip_sample,
    )
    if not os.path.exists(args.model_id_or_path):
        pretrained_model_path = snapshot_download(args.model_id_or_path, revision=args.model_revision)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder='vae')
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder='tokenizer')
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder='text_encoder')

    motion_adapter = None
    if args.motion_adapter_id_or_path is not None:
        if not os.path.exists(args.motion_adapter_id_or_path):
            args.motion_adapter_id_or_path = snapshot_download(
                args.motion_adapter_id_or_path, revision=args.motion_adapter_revision)
        motion_adapter = MotionAdapter.from_pretrained(args.motion_adapter_id_or_path)
    unet: UNetMotionModel = UNetMotionModel.from_unet2d(
        UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder='unet'),
        motion_adapter=motion_adapter,
        load_weights=True,
    )

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set unet trainable parameters
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        if re.fullmatch(args.trainable_modules, name):
            param.requires_grad = True

    # Preparing LoRA
    if args.sft_type == 'lora':
        if args.motion_adapter_id_or_path is None:
            raise ValueError('No AnimateDiff weight found, Please do not use LoRA.')
        lora_config = LoRAConfig(
            r=args.lora_rank,
            target_modules=args.trainable_modules,
            lora_alpha=args.lora_alpha,
            lora_dtype=args.lora_dtype,
            lora_dropout=args.lora_dropout_p)
        unet = Swift.prepare_model(unet, lora_config)
        logger.info(f'lora_config: {lora_config}')

    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.learning_rate,
        weight_decay=args.weight_decay,
    )

    if is_main_process:
        print(f'trainable params number: {len(trainable_params)}')
        print(f'trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M')

    # Enable xformers
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError('xformers is not available. Make sure it is installed correctly')

    # Enable gradient checkpointing
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = AnimateDiffDataset(
        csv_path=args.csv_path,
        video_folder=args.video_folder,
        sample_size=args.sample_size,
        sample_stride=args.sample_stride,
        sample_n_frames=args.sample_n_frames,
        dataset_sample_size=args.dataset_sample_size,
    )

    if not is_dist():
        sampler = RandomSampler(train_dataset)
    else:
        sampler = DistributedSampler(
            train_dataset, num_replicas=num_processes, rank=global_rank, shuffle=True, seed=global_seed)

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Get the training iteration
    max_train_steps = args.num_train_epochs * len(train_dataloader)
    print(f'max_train_steps: {max_train_steps}')

    # Scheduler
    lr_scheduler = get_scheduler(
        args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=int(args.warmup_ratio * max_train_steps) // args.gradient_accumulation_steps,
        num_training_steps=max_train_steps // args.gradient_accumulation_steps,
    )

    unet.to(local_rank)
    if is_dist():
        unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    num_train_epochs = args.num_train_epochs

    # Train!
    total_batch_size = args.batch_size * num_processes * args.gradient_accumulation_steps

    if is_main_process:
        logging.info('***** Running training *****')
        logging.info(f'  Num examples = {len(train_dataset)}')
        logging.info(f'  Num Epochs = {num_train_epochs}')
        logging.info(f'  Instantaneous batch size per device = {args.batch_size}')
        logging.info(f'  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}')
        logging.info(f'  Gradient Accumulation steps = {args.gradient_accumulation_steps}')
        logging.info(f'  Total optimization steps = {max_train_steps}')
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description('Steps')

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None

    for epoch in range(first_epoch, num_train_epochs):
        if is_dist():
            train_dataloader.sampler.set_epoch(epoch)

        unet.train()

        for step, batch in enumerate(train_dataloader):
            if args.text_dropout_rate > 0:
                batch['text'] = [name if random.random() > args.text_dropout_rate else '' for name in batch['text']]

            # Data batch sanity check
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
                pixel_values = rearrange(pixel_values, 'b f c h w -> b c f h w')
                for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                    pixel_value = pixel_value[None, ...]
                    file_name = '-'.join(text.replace('/',
                                                      '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'
                    save_videos_grid(pixel_value, f'{output_dir}/sanity_check/{file_name}.gif', rescale=True)

            # Convert videos to latent space
            pixel_values = batch['pixel_values'].to(local_rank)
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                pixel_values = rearrange(pixel_values, 'b f c h w -> (b f) c h w')
                latents = vae.encode(pixel_values).latent_dist
                latents = latents.sample()
                latents = rearrange(latents, '(b f) c h w -> b c f h w', f=video_length)
                latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz, ), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'],
                    max_length=tokenizer.model_max_length,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt').input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == 'epsilon':
                target = noise
            elif noise_scheduler.config.prediction_type == 'v_prediction':
                raise NotImplementedError
            else:
                raise ValueError(f'Unknown prediction type {noise_scheduler.config.prediction_type}')

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=args.mixed_precision):
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')

            # Backpropagate
            if args.mixed_precision:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if step % args.gradient_accumulation_steps == 0:
                # Backpropagate
                if args.mixed_precision:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                    optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step()

            progress_bar.update(1)
            global_step += 1

            # Wandb logging
            if is_main_process and args.use_wandb:
                wandb.log({'train_loss': loss.item()}, step=global_step)

            # Save checkpoint
            if is_main_process and (global_step % args.save_steps == 0 or step == len(train_dataloader) - 1):
                save_path = os.path.join(output_dir, 'checkpoints')
                if step == len(train_dataloader) - 1:
                    if isinstance(unet, DDP):
                        unet.module.save_pretrained(os.path.join(save_path, 'iter-last'))
                    else:
                        unet.save_pretrained(os.path.join(save_path, 'iter-last'))
                    if args.push_to_hub:
                        push_to_hub(
                            repo_name=args.hub_model_id,
                            output_dir=os.path.join(save_path, 'iter-last'),
                            token=args.hub_token,
                            private=True,
                        )
                    logging.info(f'Saved state to {os.path.join(save_path, "iter-last")} on the last step')
                else:
                    iter_save_path = os.path.join(save_path, f'iter-{global_step}')
                    if isinstance(unet, DDP):
                        unet.module.save_pretrained(iter_save_path)
                    else:
                        unet.save_pretrained(iter_save_path)
                    if args.push_to_hub and args.push_hub_strategy == 'all_checkpoints':
                        push_to_hub(
                            repo_name=args.hub_model_id,
                            output_dir=os.path.join(save_path, f'iter-{global_step}'),
                            token=args.hub_token,
                            private=True,
                        )
                    logging.info(
                        f'Saved state to {os.path.join(save_path, f"iter-{global_step}")} (global_step: {global_step})')

            # Periodically validation
            if is_main_process and global_step % args.eval_steps == 0:

                generator = torch.Generator(device=latents.device)
                generator.manual_seed(global_seed)
                Swift.merge(unet)
                height = args.sample_size
                width = args.sample_size

                def state_dict(self,
                               *args,
                               destination=None,
                               prefix='',
                               keep_vars=False,
                               adapter_name: str = None,
                               **kwargs):
                    state_dict = self.state_dict_origin()
                    return {
                        key.replace('base_layer.', ''): value
                        for key, value in state_dict.items() if 'lora' not in key
                    }

                motion_adapter = MotionAdapter(
                    motion_num_attention_heads=args.motion_num_attention_heads,
                    motion_max_seq_length=args.motion_max_seq_length)

                module = unet if not isinstance(unet, DDP) else unet.module
                motion_adapter.mid_block.motion_modules = deepcopy(module.mid_block.motion_modules)
                motion_adapter.mid_block.motion_modules.state_dict_origin = \
                    motion_adapter.mid_block.motion_modules.state_dict
                motion_adapter.mid_block.motion_modules.state_dict = MethodType(state_dict,
                                                                                motion_adapter.mid_block.motion_modules)
                for db1, db2 in zip(motion_adapter.down_blocks, module.down_blocks):
                    db1.motion_modules = deepcopy(db2.motion_modules)
                    db1.motion_modules.state_dict_origin = db1.motion_modules.state_dict
                    db1.motion_modules.state_dict = MethodType(state_dict, db1.motion_modules)
                for db1, db2 in zip(motion_adapter.up_blocks, module.up_blocks):
                    db1.motion_modules = deepcopy(db2.motion_modules)
                    db1.motion_modules.state_dict_origin = db1.motion_modules.state_dict
                    db1.motion_modules.state_dict = MethodType(state_dict, db1.motion_modules)

                Swift.unmerge(unet)
                validation_pipeline = AnimateDiffPipeline(
                    unet=UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder='unet'),
                    vae=vae,
                    tokenizer=tokenizer,
                    motion_adapter=motion_adapter,
                    text_encoder=text_encoder,
                    scheduler=noise_scheduler,
                ).to('cuda')
                validation_pipeline.enable_vae_slicing()
                validation_pipeline.enable_model_cpu_offload()

                for idx, prompt in enumerate(validation_data):
                    output = validation_pipeline(
                        prompt=prompt,
                        negative_prompt='bad quality, worse quality',
                        num_frames=args.sample_n_frames,
                        height=height,
                        width=width,
                        guidance_scale=args.guidance_scale,
                        num_inference_steps=args.num_inference_steps,
                        generator=torch.Generator('cpu').manual_seed(global_seed),
                    )
                    frames = output.frames[0]
                    export_to_gif(frames, f'{output_dir}/samples/sample-{global_step}-{idx}.gif')
                unet.train()

            logs = {'step_loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    if is_dist():
        dist.destroy_process_group()


animatediff_main = get_main(AnimateDiffArguments, animatediff_sft)