import gc
import math
import os
import random
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (default, disabled_train, get_obj_from_str,
                      instantiate_from_config, log_txt_as_img)
from torch import nn

from sat import mpu
from sat.helpers import print_rank0
from sat.model.finetune.lora2 import merge_linear_lora


class SATVideoDiffusionEngine(nn.Module):

    def __init__(self, args, **kwargs):
        super().__init__()
        model_config = args.model_config
        # model args preprocess
        log_keys = model_config.get('log_keys', None)
        input_key = model_config.get('input_key', 'mp4')
        network_config = model_config.get('network_config', None)
        network_wrapper = model_config.get('network_wrapper', None)
        denoiser_config = model_config.get('denoiser_config', None)
        sampler_config = model_config.get('sampler_config', None)
        conditioner_config = model_config.get('conditioner_config', None)
        first_stage_config = model_config.get('first_stage_config', None)
        loss_fn_config = model_config.get('loss_fn_config', None)
        scale_factor = model_config.get('scale_factor', 1.0)
        latent_input = model_config.get('latent_input', False)
        disable_first_stage_autocast = model_config.get(
            'disable_first_stage_autocast', False)
        no_cond_log = model_config.get('no_cond_log', False)
        not_trainable_prefixes = model_config.get(
            'not_trainable_prefixes', ['first_stage_model', 'conditioner'])
        compile_model = model_config.get('compile_model', False)
        en_and_decode_n_samples_a_time = model_config.get(
            'en_and_decode_n_samples_a_time', None)
        lr_scale = model_config.get('lr_scale', None)
        lora_train = model_config.get('lora_train', False)
        self.use_pd = model_config.get('use_pd', False)  # progressive distillation

        self.log_keys = log_keys
        self.input_key = input_key
        self.not_trainable_prefixes = not_trainable_prefixes
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.lr_scale = lr_scale
        self.lora_train = lora_train
        self.noised_image_input = model_config.get('noised_image_input',
                                                   False)
        self.noised_image_all_concat = model_config.get(
            'noised_image_all_concat', False)
        self.noised_image_dropout = model_config.get('noised_image_dropout',
                                                     0.0)

        if args.fp16:
            dtype = torch.float16
            dtype_str = 'fp16'
        elif args.bf16:
            dtype = torch.bfloat16
            dtype_str = 'bf16'
        else:
            dtype = torch.float32
            dtype_str = 'fp32'
        self.dtype = dtype
        self.dtype_str = dtype_str

        network_config['params']['dtype'] = dtype_str
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(
            default(network_wrapper, OPENAIUNETWRAPPER))(
                model, compile_model=compile_model, dtype=dtype)

        self.denoiser = instantiate_from_config(denoiser_config)
        self.sampler = instantiate_from_config(
            sampler_config) if sampler_config is not None else None
        self.conditioner = instantiate_from_config(
            default(conditioner_config, UNCONDITIONAL_CONFIG))

        self._init_first_stage(first_stage_config)

        self.loss_fn = instantiate_from_config(
            loss_fn_config) if loss_fn_config is not None else None

        self.latent_input = latent_input
        self.scale_factor = scale_factor
        self.disable_first_stage_autocast = disable_first_stage_autocast
        self.no_cond_log = no_cond_log
        self.device = args.device

    def disable_untrainable_params(self):
        total_trainable = 0
        for n, p in self.named_parameters():
            if not p.requires_grad:
                continue
            flag = False
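            # freeze parameters whose name matches a not-trainable prefix
            # (or everything when the prefix is 'all'), but keep LoRA
            # matrices (matrix_A / matrix_B) trainable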
            for prefix in self.not_trainable_prefixes:
                if n.startswith(prefix) or prefix == 'all':
                    flag = True
                    break

            lora_prefix = ['matrix_A', 'matrix_B']
            for prefix in lora_prefix:
                if prefix in n:
                    flag = False
                    break

            if flag:
                p.requires_grad_(False)
            else:
                total_trainable += p.numel()

        print_rank0('***** Total trainable parameters: ' +
                    str(total_trainable) + ' *****')

    def reinit(self, parent_model=None):
        # reload the initial params from previous trained modules
        # you can also get access to other mixins through parent_model.get_mixin().
        pass

    def merge_lora(self):
        for m in self.model.diffusion_model.mixins.adaln_layer.adaLN_modulations:
            m[1] = merge_linear_lora(m[1])

    def _init_first_stage(self, config):
        model = instantiate_from_config(config).eval()
        model.train = disabled_train
        for param in model.parameters():
            param.requires_grad = False
        self.first_stage_model = model

    def forward(self, x, batch):
        loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x,
                            batch)
        loss_mean = loss.mean()
        loss_dict = {'loss': loss_mean}
        return loss_mean, loss_dict

    def add_noise_to_first_frame(self, image):
        # noise level is drawn log-normally: sigma = exp(N(-3.0, 0.5))
        sigma = torch.normal(mean=-3.0, std=0.5,
                             size=(image.shape[0], )).to(self.device)
        sigma = torch.exp(sigma).to(image.dtype)
        image_noise = torch.randn_like(image) * sigma[:, None, None, None,
                                                      None]
        image = image + image_noise
        return image

    @torch.no_grad()
    def save_memory_encode_first_stage(self, x, batch):
        # encode a 49-frame clip in 13 + 12 + 12 + 12 frame chunks along the
        # temporal axis so the VAE never sees the full clip at once
        num_frames = x.shape[2]
        splits_x = torch.split(x, [13, 12, 12, 12], dim=2)
        all_out = []
        with torch.autocast('cuda', enabled=False):
            for idx, input_x in enumerate(splits_x):
                if idx == len(splits_x) - 1:
                    clear_fake_cp_cache = True
                else:
                    clear_fake_cp_cache = False
                out = self.first_stage_model.encode(
                    input_x.contiguous(),
                    clear_fake_cp_cache=clear_fake_cp_cache)
                all_out.append(out)
        z = torch.cat(all_out, dim=2)
        z = 1.15258426 * z
        return z

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        # print(f"this is iteration {self.share_cache['iteration']}", flush=True)
        # print(f'''{"train_size_range" in self.share_cache}''', flush=True)

        if 'train_size_range' in self.share_cache:
            train_size_range = self.share_cache.get('train_size_range')
            size_factor = random.uniform(*train_size_range)
            # broadcast the size factor from rank 0
            size_factor = torch.tensor(size_factor).to(self.device)
            torch.distributed.broadcast(size_factor,
                                        src=0,
                                        group=mpu.get_data_parallel_group())
            # print(f"size_factor: {size_factor} at rank : {torch.distributed.get_rank()}", flush=True)
            target_size = (int(x.shape[3] * size_factor),
                           int(x.shape[4] * size_factor))
            # print(target_size)
            # make sure it can be divided by 16
            b, t, c, h, w = x.shape
            # reshape to b * t, c, h, w
            x = x.reshape(b * t, c, h, w)
            target_size = (target_size[0] // 16 * 16,
                           target_size[1] // 16 * 16)
            x = F.interpolate(x,
                              size=target_size,
                              mode='bilinear',
                              align_corners=False,
                              antialias=True)
            # reshape back to b, t, c, h, w
            x = x.reshape(b, t, c, target_size[0], target_size[1])

        if self.lr_scale is not None:
            lr_x = F.interpolate(x,
                                 scale_factor=1 / self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_x = F.interpolate(lr_x,
                                 scale_factor=self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_z = self.encode_first_stage(lr_x, batch)
            batch['lr_input'] = lr_z

        x = x.permute(0, 2, 1, 3, 4).contiguous()
        if self.noised_image_input:
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)

        b, c, t, h, w = x.shape
        if t == 49 and (h * w) > 480 * 720:
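            # 49-frame clips above 480x720 are encoded in temporal chunks to
            # save GPU memory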
            if os.environ.get('DEBUGINFO', None) is not None:
                print(
                    f'save memory encode first stage with input shape {x.shape}, {x.mean()}'
                )
            x = self.save_memory_encode_first_stage(x, batch)
        else:
            x = self.encode_first_stage(x, batch)
        # x = self.encode_first_stage(x, batch)
        x = x.permute(0, 2, 1, 3, 4).contiguous()

        if 'ref_mp4' in self.share_cache:
            if 'disable_ref' not in self.share_cache:
                ref_mp4 = self.share_cache.pop('ref_mp4')
                ref_mp4 = ref_mp4.to(self.dtype).to(self.device)
                ref_mp4 = ref_mp4.permute(0, 2, 1, 3, 4).contiguous()
                ref_x = self.encode_first_stage(ref_mp4, batch)
                ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
                self.share_cache['ref_x'] = ref_x

        if self.noised_image_input:
            image = image.permute(0, 2, 1, 3, 4).contiguous()
            if self.noised_image_all_concat:
                image = image.repeat(1, x.shape[1], 1, 1, 1)
            else:
                image = torch.concat([image, torch.zeros_like(x[:, 1:])],
                                     dim=1)
            if random.random() < self.noised_image_dropout:
                image = torch.zeros_like(image)
            batch['concat_images'] = image

        # gc.collect()
        # torch.cuda.empty_cache()
        loss, loss_dict = self(x, batch)
        return loss, loss_dict

    def get_input(self, batch):
        return batch[self.input_key].to(self.dtype)

    @torch.no_grad()
    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
        n_rounds = math.ceil(z.shape[0] / n_samples)
        all_out = []
        with torch.autocast('cuda',
                            enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                if isinstance(self.first_stage_model.decoder, VideoDecoder):
                    kwargs = {
                        'timesteps': len(z[n * n_samples:(n + 1) * n_samples])
                    }
                else:
                    kwargs = {}
                out = self.first_stage_model.decode(
                    z[n * n_samples:(n + 1) * n_samples], **kwargs)
                all_out.append(out)
        out = torch.cat(all_out, dim=0)
        return out

    @torch.no_grad()
    def encode_first_stage(self, x, batch):
        frame = x.shape[2]

        if frame > 1 and self.latent_input:
            x = x.permute(0, 2, 1, 3, 4).contiguous()
            return x * self.scale_factor  # already encoded

        n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
        n_rounds = math.ceil(x.shape[0] / n_samples)
        all_out = []
        with torch.autocast('cuda',
                            enabled=not self.disable_first_stage_autocast):
            for n in range(n_rounds):
                out = self.first_stage_model.encode(
                    x[n * n_samples:(n + 1) * n_samples])
                all_out.append(out)
        z = torch.cat(all_out, dim=0)
        z = self.scale_factor * z
        return z

    @torch.no_grad()
    def sample(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        prefix=None,
        concat_images=None,
        **kwargs,
    ):
        randn = torch.randn(batch_size,
                            *shape).to(torch.float32).to(self.device)
        if hasattr(self, 'seeded_noise'):
            randn = self.seeded_noise(randn)

        if prefix is not None:
            randn = torch.cat([prefix, randn[:, prefix.shape[1]:]], dim=1)

        # broadcast noise
        mp_size = mpu.get_model_parallel_world_size()
        if mp_size > 1:
            global_rank = torch.distributed.get_rank() // mp_size
            src = global_rank * mp_size
            torch.distributed.broadcast(randn,
                                        src=src,
                                        group=mpu.get_model_parallel_group())

        scale = None
        scale_emb = None

        denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
            self.model,
            input,
            sigma,
            c,
            concat_images=concat_images,
            **additional_model_inputs)

        if 'cfg' in self.share_cache:
            self.sampler.guider.scale = self.share_cache['cfg']
            print('overwrite cfg scale in config of stage-1')
        samples = self.sampler(denoiser,
                               randn,
                               cond,
                               uc=uc,
                               scale=scale,
                               scale_emb=scale_emb,
                               num_steps=kwargs.get('num_steps', None))
        samples = samples.to(self.dtype)
        return samples
    @torch.no_grad()
    def log_conditionings(self, batch: Dict, n: int) -> Dict:
        """Defines heuristics to log different conditionings.

        These can be lists of strings (text-to-image), tensors, ints, ...
        """
        image_h, image_w = batch[self.input_key].shape[3:]
        log = dict()

        for embedder in self.conditioner.embedders:
            if ((self.log_keys is None) or
                (embedder.input_key in self.log_keys)) and not self.no_cond_log:
                x = batch[embedder.input_key][:n]
                if isinstance(x, torch.Tensor):
                    if x.dim() == 1:
                        # class-conditional, convert integer to string
                        x = [str(x[i].item()) for i in range(x.shape[0])]
                        xc = log_txt_as_img((image_h, image_w),
                                            x,
                                            size=image_h // 4)
                    elif x.dim() == 2:
                        # size and crop cond and the like
                        x = [
                            'x'.join([str(xx) for xx in x[i].tolist()])
                            for i in range(x.shape[0])
                        ]
                        xc = log_txt_as_img((image_h, image_w),
                                            x,
                                            size=image_h // 20)
                    else:
                        raise NotImplementedError()
                elif isinstance(x, (List, ListConfig)):
                    if isinstance(x[0], str):
                        xc = log_txt_as_img((image_h, image_w),
                                            x,
                                            size=image_h // 20)
                    else:
                        raise NotImplementedError()
                else:
                    raise NotImplementedError()
                log[embedder.input_key] = xc
        return log

    @torch.no_grad()
    def log_video(
        self,
        batch: Dict,
        N: int = 8,
        ucg_keys: List[str] = None,
        only_log_video_latents=False,
        **kwargs,
    ) -> Dict:
        conditioner_input_keys = [
            e.input_key for e in self.conditioner.embedders
        ]
        if ucg_keys:
            assert all(map(lambda x: x in conditioner_input_keys,
                           ucg_keys)), (
                'Each defined ucg key for sampling must be in the provided '
                f'conditioner input keys, but we have {ucg_keys} vs. '
                f'{conditioner_input_keys}')
        else:
            ucg_keys = conditioner_input_keys
        log = dict()

        x = self.get_input(batch)

        c, uc = self.conditioner.get_unconditional_conditioning(
            batch,
            force_uc_zero_embeddings=ucg_keys
            if len(self.conditioner.embedders) > 0 else [],
        )

        sampling_kwargs = {}

        N = min(x.shape[0], N)
        x = x.to(self.device)[:N]
        if not self.latent_input:
            log['inputs'] = x.to(torch.float32)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        z = self.encode_first_stage(x, batch)
        if not only_log_video_latents:
            log['reconstructions'] = self.decode_first_stage(z).to(
                torch.float32)
            log['reconstructions'] = log['reconstructions'].permute(
                0, 2, 1, 3, 4).contiguous()
        z = z.permute(0, 2, 1, 3, 4).contiguous()

        log.update(self.log_conditionings(batch, N))

        for k in c:
            if isinstance(c[k], torch.Tensor):
                c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))

        if self.noised_image_input:
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)
            image = image.permute(0, 2, 1, 3, 4).contiguous()
            image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
            c['concat'] = image
            uc['concat'] = image
            samples = self.sample(c,
                                  shape=z.shape[1:],
                                  uc=uc,
                                  batch_size=N,
                                  **sampling_kwargs)  # b t c h w
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            if only_log_video_latents:
                latents = 1.0 / self.scale_factor * samples
                log['latents'] = latents
            else:
                samples = self.decode_first_stage(samples).to(torch.float32)
                samples = samples.permute(0, 2, 1, 3, 4).contiguous()
                log['samples'] = samples
        else:
            samples = self.sample(c,
                                  shape=z.shape[1:],
                                  uc=uc,
                                  batch_size=N,
                                  **sampling_kwargs)  # b t c h w
            samples = samples.permute(0, 2, 1, 3, 4).contiguous()
            if only_log_video_latents:
                latents = 1.0 / self.scale_factor * samples
                log['latents'] = latents
            else:
                samples = self.decode_first_stage(samples).to(torch.float32)
                samples = samples.permute(0, 2, 1, 3, 4).contiguous()
                log['samples'] = samples
        return log


class SATUpscalerEngine(SATVideoDiffusionEngine):

    def shared_step(self, batch: Dict) -> Any:
        x = self.get_input(batch)
        if self.lr_scale is not None:
            lr_x = F.interpolate(x,
                                 scale_factor=1 / self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
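            # upsample back to the original resolution; the blurred result is
            # encoded below and stored as batch['lr_input']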
            lr_x = F.interpolate(lr_x,
                                 scale_factor=self.lr_scale,
                                 mode='bilinear',
                                 align_corners=False)
            lr_z = self.encode_first_stage(lr_x, batch)
            batch['lr_input'] = lr_z

        x = x.permute(0, 2, 1, 3, 4).contiguous()
        if self.noised_image_input:
            image = x[:, :, 0:1]
            image = self.add_noise_to_first_frame(image)
            image = self.encode_first_stage(image, batch)

        x = self.encode_first_stage(x, batch)
        x = x.permute(0, 2, 1, 3, 4).contiguous()

        if 'ref_mp4' in self.share_cache:
            ref_mp4 = self.share_cache.pop('ref_mp4')
            ref_mp4 = ref_mp4.to(self.dtype).to(self.device)
            ref_mp4 = ref_mp4.permute(0, 2, 1, 3, 4).contiguous()
            ref_x = self.encode_first_stage(ref_mp4, batch)
            ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
            self.share_cache['ref_x'] = ref_x

        if self.noised_image_input:
            image = image.permute(0, 2, 1, 3, 4).contiguous()
            if self.noised_image_all_concat:
                image = image.repeat(1, x.shape[1], 1, 1, 1)
            else:
                image = torch.concat([image, torch.zeros_like(x[:, 1:])],
                                     dim=1)
            if random.random() < self.noised_image_dropout:
                image = torch.zeros_like(image)
            batch['concat_images'] = image

        ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
        ref_x = self.first_stage_model.decoder(ref_x)
        ref_x = ref_x.permute(0, 2, 1, 3, 4).contiguous()
        loss_mean = torch.mean(((x - ref_x)**2).reshape(x.shape[0], -1), 1)
        loss_mean = loss_mean.mean()
        loss_dict = {'loss': loss_mean}
        return loss_mean, loss_dict

    def disable_untrainable_params(self):
        pass

    # def forward(self, x, batch):
    #     loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
    #     loss_mean = loss.mean()
    #     loss_dict = {"loss": loss_mean}
    #     return loss_mean, loss_dict
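
# Minimal usage sketch (an assumption about the surrounding trainer, not part
# of this module): the engine is normally built by the SAT training entry
# point from an `args` namespace carrying `model_config` (the network /
# denoiser / sampler / conditioner / first-stage / loss configs read in
# __init__), the `fp16` / `bf16` flags and `device`, roughly:
#
#     model = SATVideoDiffusionEngine(args)
#     model.share_cache = {}          # read by shared_step / sample
#     model.disable_untrainable_params()
#     loss, loss_dict = model.shared_step(batch)  # batch['mp4']: (b, t, c, h, w)
#
# Note that `self.share_cache` is only ever read here ('train_size_range',
# 'ref_mp4', 'disable_ref', 'cfg'), so the trainer must attach it before the
# first call.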