Commit 1f5da520 authored by yangzhong

git init

import math
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union, Optional
from omegaconf import ListConfig, OmegaConf
from copy import deepcopy
import torch.nn.functional as F
from sat.helpers import print_rank0
import torch
from torch import nn
from sgm.modules import UNCONDITIONAL_CONFIG
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
from sgm.util import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
log_txt_as_img,
)
import gc
from sat import mpu
import random
class SATVideoDiffusionEngine(nn.Module):
def __init__(self, args, **kwargs):
super().__init__()
model_config = args.model_config
# model args preprocess
log_keys = model_config.get("log_keys", None)
input_key = model_config.get("input_key", "mp4")
network_config = model_config.get("network_config", None)
network_wrapper = model_config.get("network_wrapper", None)
denoiser_config = model_config.get("denoiser_config", None)
sampler_config = model_config.get("sampler_config", None)
conditioner_config = model_config.get("conditioner_config", None)
first_stage_config = model_config.get("first_stage_config", None)
loss_fn_config = model_config.get("loss_fn_config", None)
scale_factor = model_config.get("scale_factor", 1.0)
latent_input = model_config.get("latent_input", False)
disable_first_stage_autocast = model_config.get("disable_first_stage_autocast", False)
no_cond_log = model_config.get("no_cond_log", False)
not_trainable_prefixes = model_config.get("not_trainable_prefixes", ["first_stage_model", "conditioner"])
compile_model = model_config.get("compile_model", False)
en_and_decode_n_samples_a_time = model_config.get("en_and_decode_n_samples_a_time", None)
lr_scale = model_config.get("lr_scale", None)
lora_train = model_config.get("lora_train", False)
self.use_pd = model_config.get("use_pd", False) # progressive distillation
self.log_keys = log_keys
self.input_key = input_key
self.not_trainable_prefixes = not_trainable_prefixes
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
self.lr_scale = lr_scale
self.lora_train = lora_train
self.noised_image_input = model_config.get("noised_image_input", False)
self.noised_image_all_concat = model_config.get("noised_image_all_concat", False)
self.noised_image_dropout = model_config.get("noised_image_dropout", 0.0)
if args.fp16:
dtype = torch.float16
dtype_str = "fp16"
elif args.bf16:
dtype = torch.bfloat16
dtype_str = "bf16"
else:
dtype = torch.float32
dtype_str = "fp32"
self.dtype = dtype
self.dtype_str = dtype_str
network_config["params"]["dtype"] = dtype_str
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
model, compile_model=compile_model, dtype=dtype
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = instantiate_from_config(sampler_config) if sampler_config is not None else None
self.conditioner = instantiate_from_config(default(conditioner_config, UNCONDITIONAL_CONFIG))
self._init_first_stage(first_stage_config)
self.loss_fn = instantiate_from_config(loss_fn_config) if loss_fn_config is not None else None
self.latent_input = latent_input
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
self.device = args.device
def disable_untrainable_params(self):
total_trainable = 0
for n, p in self.named_parameters():
if not p.requires_grad:
continue
flag = False
for prefix in self.not_trainable_prefixes:
if n.startswith(prefix) or prefix == "all":
flag = True
break
lora_prefix = ["matrix_A", "matrix_B", "final_layer", "proj_sr", "local"]
for prefix in lora_prefix:
if prefix in n:
flag = False
break
if flag:
p.requires_grad_(False)
else:
print(n)
total_trainable += p.numel()
print_rank0("***** Total trainable parameters: " + str(total_trainable / 1000000) + "M *****")
def reinit(self, parent_model=None):
# reload the initial params from previous trained modules
# you can also get access to other mixins through parent_model.get_mixin().
pass
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
def forward(self, x, hq_video, batch):
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch, hq_video, self.decode_first_stage)
loss_mean = loss.mean()
loss_dict = {"loss": loss_mean}
return loss_mean, loss_dict
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
if self.lr_scale is not None:
lr_x = F.interpolate(x, scale_factor=1 / self.lr_scale, mode="bilinear", align_corners=False)
lr_x = F.interpolate(lr_x, scale_factor=self.lr_scale, mode="bilinear", align_corners=False)
lr_z = self.encode_first_stage(lr_x, batch)
batch["lr_input"] = lr_z
x = x.permute(0, 2, 1, 3, 4).contiguous() # (B, T, C, H, W) -> (B, C, T, H, W)
hq_video = x # (B, C, T, H, W)
x = self.encode_first_stage(x, batch)
x = x.permute(0, 2, 1, 3, 4).contiguous() # (B, C, T, H, W) -> (B, T, C, H, W)
if 'lq' in batch.keys():
# print('LQ is NOT None')
lq = batch['lq'].to(self.dtype)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
lq = self.encode_first_stage(lq, batch)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
batch['lq'] = lq
# Uncomment for t2v training,
# batch['lq'] = None
gc.collect()
torch.cuda.empty_cache()
loss, loss_dict = self(x, hq_video, batch)
return loss, loss_dict
def get_input(self, batch):
return batch[self.input_key].to(self.dtype)
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
use_cp = False
out = self.first_stage_model.decode(z[n * n_samples : (n + 1) * n_samples], **kwargs)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x, batch=None):
frame = x.shape[2]
if frame > 1 and self.latent_input:
x = x.permute(0, 2, 1, 3, 4).contiguous()
return x * self.scale_factor # already encoded
use_cp = False
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
for n in range(n_rounds):
out = self.first_stage_model.encode(x[n * n_samples : (n + 1) * n_samples])
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
prefix=None,
concat_images=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
if hasattr(self, "seeded_noise"):
randn = self.seeded_noise(randn)
if prefix is not None:
randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
# broadcast noise
mp_size = mpu.get_model_parallel_world_size()
if mp_size > 1:
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
scale = None
scale_emb = None
denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
)
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb)
samples = samples.to(self.dtype)
return samples
@torch.no_grad()
def sample_sr(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
lq=None,
prefix=None,
concat_images=None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(torch.float32).to(self.device)
if hasattr(self, "seeded_noise"):
randn = self.seeded_noise(randn)
if prefix is not None:
randn = torch.cat([prefix, randn[:, prefix.shape[1] :]], dim=1)
# broadcast noise
mp_size = mpu.get_model_parallel_world_size()
if mp_size > 1:
global_rank = torch.distributed.get_rank() // mp_size
src = global_rank * mp_size
torch.distributed.broadcast(randn, src=src, group=mpu.get_model_parallel_group())
scale = None
scale_emb = None
denoiser = lambda input, sigma, c, **additional_model_inputs: self.denoiser(
self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
)
# add lq condition (new)
lq = lq.to(randn.device, self.dtype)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
lq = self.encode_first_stage(lq)
lq = lq.permute(0, 2, 1, 3, 4).contiguous()
lq = torch.cat((lq, lq), dim=0) # for CFG inference
# For T2V
# lq = None
# print('randn shape:', randn.shape) # torch.Size([1, 8, 16, 60, 90])
# print('lq shape:', lq.shape) # torch.Size([1, 8, 16, 60, 90])
samples = self.sampler(denoiser, randn, cond, uc=uc, scale=scale, scale_emb=scale_emb, lq=lq)
samples = samples.to(self.dtype)
return samples
@torch.no_grad()
def log_conditionings(self, batch: Dict, n: int) -> Dict:
"""
Defines heuristics to log different conditionings.
These can be lists of strings (text-to-image), tensors, ints, ...
"""
image_h, image_w = batch[self.input_key].shape[3:]
log = dict()
for embedder in self.conditioner.embedders:
if ((self.log_keys is None) or (embedder.input_key in self.log_keys)) and not self.no_cond_log:
x = batch[embedder.input_key][:n]
if isinstance(x, torch.Tensor):
if x.dim() == 1:
# class-conditional, convert integer to string
x = [str(x[i].item()) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
elif x.dim() == 2:
# size and crop cond and the like
x = ["x".join([str(xx) for xx in x[i].tolist()]) for i in range(x.shape[0])]
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
elif isinstance(x, (List, ListConfig)):
if isinstance(x[0], str):
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
else:
raise NotImplementedError()
else:
raise NotImplementedError()
log[embedder.input_key] = xc
return log
@torch.no_grad()
def log_video(
self,
batch: Dict,
N: int = 8,
ucg_keys: List[str] = None,
only_log_video_latents=False,
**kwargs,
) -> Dict:
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
if ucg_keys:
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
"Each defined ucg key for sampling must be in the provided conditioner input keys, "
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
)
else:
ucg_keys = conditioner_input_keys
log = dict()
x = self.get_input(batch)
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
if not self.latent_input:
log["inputs"] = x.to(torch.float32)
x = x.permute(0, 2, 1, 3, 4).contiguous()
z = self.encode_first_stage(x, batch)
if not only_log_video_latents:
log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
z = z.permute(0, 2, 1, 3, 4).contiguous()
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
if only_log_video_latents:
latents = 1.0 / self.scale_factor * samples
log["latents"] = latents
else:
samples = self.decode_first_stage(samples).to(torch.float32)
samples = samples.permute(0, 2, 1, 3, 4).contiguous()
log["samples"] = samples
return log
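# Layout note (editorial, derived from the code above): batch[self.input_key] is
# expected as (B, T, C, H, W); the first-stage VAE consumes (B, C, T, H, W), which
# is why shared_step and log_video permute before encode_first_stage and permute
# back afterwards, while the DiT itself operates on the (B, T, C, H, W) latents.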
from functools import partial
from einops import rearrange, repeat
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from sat.model.base_model import BaseModel, non_conflict
from sat.model.mixins import BaseMixin
from sat.transformer_defaults import HOOKS_DEFAULT, attention_fn_default
from sat.mpu.layers import ColumnParallelLinear
from sgm.util import instantiate_from_config
from sgm.modules.diffusionmodules.openaimodel import Timestep
from sgm.modules.diffusionmodules.util import (
linear,
timestep_embedding,
)
from sat.ops.layernorm import LayerNorm, RMSNorm
class ImagePatchEmbeddingMixin(BaseMixin):
def __init__(
self,
in_channels,
hidden_size,
patch_size,
bias=True,
text_hidden_size=None,
):
super().__init__()
# print(in_channels)
# self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
self.proj_sr = nn.Conv2d(in_channels * 2, hidden_size, kernel_size=patch_size, stride=patch_size, bias=bias)
# Copy the weights of the first 16 channels from the original projection layer
# self.proj_sr.weight.data[:, :in_channels, :, :] = self.proj.weight.data.clone()
# # Initialize the weights of the last 16 channels to zero
# torch.nn.init.constant_(self.proj_sr.weight.data[:, in_channels:, :, :], 0)
# # If bias is used, copy the original bias values directly
# if bias:
# self.proj_sr.bias.data = self.proj.bias.data.clone()
if text_hidden_size is not None:
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
else:
self.text_proj = None
def word_embedding_forward(self, input_ids, **kwargs):
# now is 3d patch
images = kwargs["images"] # (b,t,c,h,w)
B, T = images.shape[:2]
emb = images.view(-1, *images.shape[2:])
#--------
# Debug
#--------
# emb_ori = emb
# x_ori, _ = emb.chunk(2, dim=1)
# emb = self.proj(x_ori)
# emb_debug = self.proj_sr(emb_ori) # ((b t),d,h/2,w/2) [2 * 8, 16, 60, 90]
# print(torch.sqrt((emb - emb_debug)**2).mean())
emb = self.proj_sr(emb) # ((b t),d,h/2,w/2) [2 * 8, 32, 60, 90]
emb = emb.view(B, T, *emb.shape[1:])
emb = emb.flatten(3).transpose(2, 3) # (b,t,n,d)
emb = rearrange(emb, "b t n d -> b (t n) d")
if self.text_proj is not None:
text_emb = self.text_proj(kwargs["encoder_outputs"])
emb = torch.cat((text_emb, emb), dim=1) # (b,n_t+t*n_i,d)
emb = emb.contiguous()
return emb # (b,n_t+t*n_i,d)
def reinit(self, parent_model=None):
w = self.proj_sr.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.proj_sr.bias, 0)
del self.transformer.word_embeddings
def get_3d_sincos_pos_embed(
embed_dim,
grid_height,
grid_width,
t_size,
cls_token=False,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
):
"""
grid_height, grid_width: ints of the spatial grid height and width
t_size: int of the temporal size
return:
pos_embed: [t_size, grid_height*grid_width, embed_dim]
"""
assert embed_dim % 4 == 0
embed_dim_spatial = embed_dim // 4 * 3
embed_dim_temporal = embed_dim // 4
# spatial
grid_h = np.arange(grid_height, dtype=np.float32) / height_interpolation
grid_w = np.arange(grid_width, dtype=np.float32) / width_interpolation
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
# temporal
grid_t = np.arange(t_size, dtype=np.float32) / time_interpolation
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
# concate: [T, H, W] order
pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
pos_embed_temporal = np.repeat(pos_embed_temporal, grid_height * grid_width, axis=1) # [T, H*W, D // 4]
pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
pos_embed_spatial = np.repeat(pos_embed_spatial, t_size, axis=0) # [T, H*W, D // 4 * 3]
pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
# pos_embed = pos_embed.reshape([-1, embed_dim]) # [T*H*W, D]
return pos_embed # [T, H*W, D]
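# Shape sketch (derived from the function above): embed_dim D is split into a
# temporal part of D//4 channels and a spatial part of 3*D//4 channels; the
# temporal embedding is repeated over all H*W positions and the spatial embedding
# over all T frames, giving a [T, H*W, D] table (the flatten to [T*H*W, D] is
# left commented out).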
def get_2d_sincos_pos_embed(embed_dim, grid_height, grid_width, cls_token=False, extra_tokens=0):
"""
grid_height, grid_width: ints of the grid height and width
return:
pos_embed: [grid_height*grid_width, embed_dim] or [extra_tokens + grid_height*grid_width, embed_dim] (w/ or w/o extra tokens)
"""
grid_h = np.arange(grid_height, dtype=np.float32)
grid_w = np.arange(grid_width, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_height, grid_width])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
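# Worked example (illustrative, not part of the original file): with embed_dim=4
# and pos=[0, 1], omega = [1, 1e-2], so
#   get_1d_sincos_pos_embed_from_grid(4, np.array([0.0, 1.0]))
# has shape (2, 4); row 0 is [sin(0), sin(0), cos(0), cos(0)] = [0, 0, 1, 1] and
# row 1 is [sin(1), sin(0.01), cos(1), cos(0.01)].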
class Basic3DPositionEmbeddingMixin(BaseMixin):
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
text_length=0,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
):
super().__init__()
self.height = height
self.width = width
self.text_length = text_length
self.compressed_num_frames = compressed_num_frames
self.spatial_length = height * width
self.num_patches = height * width * compressed_num_frames
self.pos_embedding = nn.Parameter(
torch.zeros(1, int(text_length + self.num_patches), int(hidden_size)), requires_grad=False
)
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
def position_embedding_forward(self, position_ids, **kwargs):
if kwargs["images"].shape[1] == 1:
return self.pos_embedding[:, : self.text_length + self.spatial_length]
return self.pos_embedding[:, : self.text_length + kwargs["seq_length"]]
def reinit(self, parent_model=None):
del self.transformer.position_embeddings
pos_embed = get_3d_sincos_pos_embed(
self.pos_embedding.shape[-1],
self.height,
self.width,
self.compressed_num_frames,
height_interpolation=self.height_interpolation,
width_interpolation=self.width_interpolation,
time_interpolation=self.time_interpolation,
)
pos_embed = torch.from_numpy(pos_embed).float()
pos_embed = rearrange(pos_embed, "t n d -> (t n) d")
self.pos_embedding.data[:, -self.num_patches :].copy_(pos_embed)
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
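# Example (illustrative): broadcat of tensors shaped (T, 1, 1, d_t), (1, H, 1, d_h)
# and (1, 1, W, d_w) along dim=-1 expands every non-concat axis to its maximum size
# and returns a (T, H, W, d_t + d_h + d_w) tensor; this is exactly how the rotary
# frequencies for time, height and width are merged in the mixin below.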
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
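# Worked example (illustrative): along the last dim, [x1, x2, x3, x4] is grouped
# into pairs (x1, x2), (x3, x4) and mapped to [-x2, x1, -x4, x3], the 90-degree
# pairwise rotation used by the RoPE attention below.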
class Rotary3DPositionEmbeddingMixin(BaseMixin):
def __init__(
self,
height,
width,
compressed_num_frames,
hidden_size,
hidden_size_head,
text_length,
theta=10000,
rot_v=False,
learnable_pos_embed=False,
):
super().__init__()
self.rot_v = rot_v
dim_t = hidden_size_head // 4
dim_h = hidden_size_head // 8 * 3
dim_w = hidden_size_head // 8 * 3
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2)[: (dim_t // 2)].float() / dim_t))
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2)[: (dim_h // 2)].float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2)[: (dim_w // 2)].float() / dim_w))
grid_t = torch.arange(compressed_num_frames, dtype=torch.float32)
grid_h = torch.arange(height, dtype=torch.float32)
grid_w = torch.arange(width, dtype=torch.float32)
freqs_t = torch.einsum("..., f -> ... f", grid_t, freqs_t)
freqs_h = torch.einsum("..., f -> ... f", grid_h, freqs_h)
freqs_w = torch.einsum("..., f -> ... f", grid_w, freqs_w)
freqs_t = repeat(freqs_t, "... n -> ... (n r)", r=2)
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
freqs = broadcat((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
freqs = rearrange(freqs, "t h w d -> (t h w) d")
freqs = freqs.contiguous()
freqs_sin = freqs.sin()
freqs_cos = freqs.cos()
self.register_buffer("freqs_sin", freqs_sin)
self.register_buffer("freqs_cos", freqs_cos)
self.text_length = text_length
if learnable_pos_embed:
num_patches = height * width * compressed_num_frames + text_length
self.pos_embedding = nn.Parameter(torch.zeros(1, num_patches, int(hidden_size)), requires_grad=True)
else:
self.pos_embedding = None
def rotary(self, t, **kwargs):
seq_len = t.shape[2]
freqs_cos = self.freqs_cos[:seq_len].unsqueeze(0).unsqueeze(0)
freqs_sin = self.freqs_sin[:seq_len].unsqueeze(0).unsqueeze(0)
return t * freqs_cos + rotate_half(t) * freqs_sin
def position_embedding_forward(self, position_ids, **kwargs):
if self.pos_embedding is not None:
return self.pos_embedding[:, :self.text_length + kwargs["seq_length"]]
else:
return None
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
**kwargs,
):
attention_fn_default = HOOKS_DEFAULT["attention_fn"]
query_layer[:, :, self.text_length :] = self.rotary(query_layer[:, :, self.text_length :])
key_layer[:, :, self.text_length :] = self.rotary(key_layer[:, :, self.text_length :])
if self.rot_v:
value_layer[:, :, self.text_length :] = self.rotary(value_layer[:, :, self.text_length :])
return attention_fn_default(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
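# Note (editorial): attention_fn above applies the rotary embedding only to the
# vision tokens (indices past text_length); text tokens rely on the optional
# learnable pos_embedding instead.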
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
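# Note (editorial): modulate() is the adaLN transform used throughout this file;
# for x of shape (b, n, d) and per-sample shift/scale of shape (b, d) it returns
# x * (1 + scale) + shift, broadcast over the token dimension n.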
def unpatchify(x, c, p, w, h, rope_position_ids=None, **kwargs):
"""
x: (N, T * S, patch_size**2 * C) where S = (H // patch_size) * (W // patch_size)
imgs: (N, T, C, H, W)
"""
if rope_position_ids is not None:
raise NotImplementedError
# do pix2struct unpatchify
L = x.shape[1]
x = x.reshape(shape=(x.shape[0], L, p, p, c))
x = torch.einsum("nlpqc->ncplq", x)
imgs = x.reshape(shape=(x.shape[0], c, p, L * p))
else:
b = x.shape[0]
imgs = rearrange(x, "b (t h w) (c p q) -> b t c (h p) (w q)", b=b, h=h, w=w, c=c, p=p, q=p)
return imgs
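# Shape sketch (derived from the rearrange above): with patch size p and an
# h x w grid of patches per frame, x: (b, t*h*w, c*p*p) -> imgs: (b, t, c, h*p, w*p).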
class FinalLayerMixin(BaseMixin):
def __init__(
self,
hidden_size,
time_embed_dim,
patch_size,
out_channels,
latent_width,
latent_height,
elementwise_affine,
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.out_channels = out_channels
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=elementwise_affine, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 2 * hidden_size, bias=True))
self.spatial_length = latent_width * latent_height // patch_size**2
self.latent_width = latent_width
self.latent_height = latent_height
def final_forward(self, logits, **kwargs):
x, emb = logits[:, kwargs["text_length"] :, :], kwargs["emb"] # x:(b,(t n),d)
shift, scale = self.adaLN_modulation(emb).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return unpatchify(
x,
c=self.out_channels,
p=self.patch_size,
w=self.latent_width // self.patch_size,
h=self.latent_height // self.patch_size,
rope_position_ids=kwargs.get("rope_position_ids", None),
**kwargs,
)
def reinit(self, parent_model=None):
nn.init.xavier_uniform_(self.linear.weight)
nn.init.constant_(self.linear.bias, 0)
class SwiGLUMixin(BaseMixin):
def __init__(self, num_layers, in_features, hidden_features, bias=False):
super().__init__()
self.w2 = nn.ModuleList(
[
ColumnParallelLinear(
in_features,
hidden_features,
gather_output=False,
bias=bias,
module=self,
name="dense_h_to_4h_gate",
)
for i in range(num_layers)
]
)
def mlp_forward(self, hidden_states, **kw_args):
x = hidden_states
origin = self.transformer.layers[kw_args["layer_id"]].mlp
x1 = origin.dense_h_to_4h(x)
x2 = self.w2[kw_args["layer_id"]](x)
hidden = origin.activation_func(x2) * x1
x = origin.dense_4h_to_h(hidden)
return x
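# Note (editorial): mlp_forward above implements SwiGLU: the extra w2 projection
# produces the gate x2, the original dense_h_to_4h produces x1, and the output is
# dense_4h_to_h(activation(x2) * x1), replacing the default GELU MLP.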
class AdaLNMixin(BaseMixin):
def __init__(
self,
width,
height,
hidden_size,
num_layers,
time_embed_dim,
compressed_num_frames,
qk_ln=True,
hidden_size_head=None,
elementwise_affine=True,
):
super().__init__()
self.num_layers = num_layers
self.width = width
self.height = height
self.compressed_num_frames = compressed_num_frames
self.adaLN_modulations = nn.ModuleList(
[nn.Sequential(nn.SiLU(), nn.Linear(time_embed_dim, 12 * hidden_size)) for _ in range(num_layers)]
)
self.qk_ln = qk_ln
if qk_ln:
self.query_layernorm_list = nn.ModuleList(
[
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
for _ in range(num_layers)
]
)
self.key_layernorm_list = nn.ModuleList(
[
LayerNorm(hidden_size_head, eps=1e-6, elementwise_affine=elementwise_affine)
for _ in range(num_layers)
]
)
def layer_forward(
self,
hidden_states,
mask,
*args,
**kwargs,
):
text_length = kwargs["text_length"]
# hidden_states (b,(n_t+t*n_i),d)
text_hidden_states = hidden_states[:, :text_length] # (b,n,d)
img_hidden_states = hidden_states[:, text_length:] # (b,(t n),d)
layer = self.transformer.layers[kwargs["layer_id"]]
adaLN_modulation = self.adaLN_modulations[kwargs["layer_id"]]
(
shift_msa,
scale_msa,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
text_shift_msa,
text_scale_msa,
text_gate_msa,
text_shift_mlp,
text_scale_mlp,
text_gate_mlp,
) = adaLN_modulation(kwargs["emb"]).chunk(12, dim=1)
gate_msa, gate_mlp, text_gate_msa, text_gate_mlp = (
gate_msa.unsqueeze(1),
gate_mlp.unsqueeze(1),
text_gate_msa.unsqueeze(1),
text_gate_mlp.unsqueeze(1),
)
# self full attention (b,(t n),d) b: batchsize; (t n): temp & spa; d: hidden_size
img_attention_input = layer.input_layernorm(img_hidden_states)
text_attention_input = layer.input_layernorm(text_hidden_states)
img_attention_input = modulate(img_attention_input, shift_msa, scale_msa)
text_attention_input = modulate(text_attention_input, text_shift_msa, text_scale_msa)
# Spatial LIEM
_, thw, _ = img_attention_input.shape
t = thw // (self.height * self.width)
spa_fea = rearrange(img_attention_input, 'b (t h w) c -> (b t) c h w', h=self.height, w=self.width)
spa_fea = layer.spa_local(spa_fea)
# Temporal LIEM
temp_fea = rearrange(spa_fea, '(b t) c h w -> (b h w) t c', t=t)
temp_fea = layer.temp_local(temp_fea)
img_attention_input = rearrange(temp_fea, '(b h w) t c -> b (t h w) c', h=self.height, w=self.width)
attention_input = torch.cat((text_attention_input, img_attention_input), dim=1) # (b,n_t+t*n_i,d)
attention_output = layer.attention(attention_input, mask, **kwargs)
text_attention_output = attention_output[:, :text_length] # (b,n,d)
img_attention_output = attention_output[:, text_length:] # (b,(t n),d)
if self.transformer.layernorm_order == "sandwich":
text_attention_output = layer.third_layernorm(text_attention_output)
img_attention_output = layer.third_layernorm(img_attention_output)
img_hidden_states = img_hidden_states + gate_msa * img_attention_output # (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_msa * text_attention_output # (b,n,d)
# mlp (b,(t n),d)
img_mlp_input = layer.post_attention_layernorm(img_hidden_states) # vision (b,(t n),d)
text_mlp_input = layer.post_attention_layernorm(text_hidden_states) # language (b,n,d)
img_mlp_input = modulate(img_mlp_input, shift_mlp, scale_mlp)
text_mlp_input = modulate(text_mlp_input, text_shift_mlp, text_scale_mlp)
mlp_input = torch.cat((text_mlp_input, img_mlp_input), dim=1) # (b,(n_t+t*n_i),d
mlp_output = layer.mlp(mlp_input, **kwargs)
img_mlp_output = mlp_output[:, text_length:] # vision (b,(t n),d)
text_mlp_output = mlp_output[:, :text_length] # language (b,n,d)
if self.transformer.layernorm_order == "sandwich":
text_mlp_output = layer.fourth_layernorm(text_mlp_output)
img_mlp_output = layer.fourth_layernorm(img_mlp_output)
img_hidden_states = img_hidden_states + gate_mlp * img_mlp_output # vision (b,(t n),d)
text_hidden_states = text_hidden_states + text_gate_mlp * text_mlp_output # language (b,n,d)
hidden_states = torch.cat((text_hidden_states, img_hidden_states), dim=1) # (b,(n_t+t*n_i),d)
return hidden_states
def reinit(self, parent_model=None):
for layer in self.adaLN_modulations:
nn.init.constant_(layer[-1].weight, 0)
nn.init.constant_(layer[-1].bias, 0)
@non_conflict
def attention_fn(
self,
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=None,
log_attention_weights=None,
scaling_attention_score=True,
old_impl=attention_fn_default,
**kwargs,
):
if self.qk_ln:
query_layernorm = self.query_layernorm_list[kwargs["layer_id"]]
key_layernorm = self.key_layernorm_list[kwargs["layer_id"]]
query_layer = query_layernorm(query_layer)
key_layer = key_layernorm(key_layer)
return old_impl(
query_layer,
key_layer,
value_layer,
attention_mask,
attention_dropout=attention_dropout,
log_attention_weights=log_attention_weights,
scaling_attention_score=scaling_attention_score,
**kwargs,
)
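# Note (editorial): each adaLN_modulation above emits 12 chunks per layer:
# shift/scale/gate for the attention branch and for the MLP branch, separately
# for image tokens and text tokens; layer_forward applies them around the shared
# full attention over the concatenated (text, image) sequence.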
str_to_dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
class DiffusionTransformer(BaseModel):
def __init__(
self,
transformer_args,
num_frames,
time_compressed_rate,
latent_width,
latent_height,
patch_size,
in_channels,
out_channels,
hidden_size,
num_layers,
num_attention_heads,
elementwise_affine,
time_embed_dim=None,
num_classes=None,
modules={},
input_time="adaln",
adm_in_channels=None,
parallel_output=True,
height_interpolation=1.0,
width_interpolation=1.0,
time_interpolation=1.0,
use_SwiGLU=False,
use_RMSNorm=False,
zero_init_y_embed=False,
**kwargs,
):
self.latent_width = latent_width
self.latent_height = latent_height
self.patch_size = patch_size
self.num_frames = num_frames
self.time_compressed_rate = time_compressed_rate
self.spatial_length = latent_width * latent_height // patch_size**2
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.model_channels = hidden_size
self.time_embed_dim = time_embed_dim if time_embed_dim is not None else hidden_size
self.num_classes = num_classes
self.adm_in_channels = adm_in_channels
self.input_time = input_time
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.is_decoder = transformer_args.is_decoder
self.elementwise_affine = elementwise_affine
self.height_interpolation = height_interpolation
self.width_interpolation = width_interpolation
self.time_interpolation = time_interpolation
self.inner_hidden_size = hidden_size * 4
self.zero_init_y_embed = zero_init_y_embed
try:
self.dtype = str_to_dtype[kwargs.pop("dtype")]
except KeyError:
self.dtype = torch.float32
if use_SwiGLU:
kwargs["activation_func"] = F.silu
elif "activation_func" not in kwargs:
approx_gelu = nn.GELU(approximate="tanh")
kwargs["activation_func"] = approx_gelu
if use_RMSNorm:
kwargs["layernorm"] = RMSNorm
else:
kwargs["layernorm"] = partial(LayerNorm, elementwise_affine=elementwise_affine, eps=1e-6)
transformer_args.num_layers = num_layers
transformer_args.hidden_size = hidden_size
transformer_args.num_attention_heads = num_attention_heads
transformer_args.parallel_output = parallel_output
super().__init__(args=transformer_args, transformer=None, **kwargs)
module_configs = modules
self._build_modules(module_configs)
if use_SwiGLU:
self.add_mixin(
"swiglu", SwiGLUMixin(num_layers, hidden_size, self.inner_hidden_size, bias=False), reinit=True
)
def _build_modules(self, module_configs):
model_channels = self.hidden_size
# time_embed_dim = model_channels * 4
time_embed_dim = self.time_embed_dim
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = nn.Sequential(
Timestep(model_channels),
nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
),
)
elif self.num_classes == "sequential":
assert self.adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(self.adm_in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
)
if self.zero_init_y_embed:
nn.init.constant_(self.label_emb[0][2].weight, 0)
nn.init.constant_(self.label_emb[0][2].bias, 0)
else:
raise ValueError()
pos_embed_config = module_configs["pos_embed_config"]
self.add_mixin(
"pos_embed",
instantiate_from_config(
pos_embed_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size=self.hidden_size,
),
reinit=True,
)
patch_embed_config = module_configs["patch_embed_config"]
self.add_mixin(
"patch_embed",
instantiate_from_config(
patch_embed_config,
patch_size=self.patch_size,
hidden_size=self.hidden_size,
in_channels=self.in_channels,
),
reinit=True,
)
if self.input_time == "adaln":
adaln_layer_config = module_configs["adaln_layer_config"]
self.add_mixin(
"adaln_layer",
instantiate_from_config(
adaln_layer_config,
height=self.latent_height // self.patch_size,
width=self.latent_width // self.patch_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
compressed_num_frames=(self.num_frames - 1) // self.time_compressed_rate + 1,
hidden_size_head=self.hidden_size // self.num_attention_heads,
time_embed_dim=self.time_embed_dim,
elementwise_affine=self.elementwise_affine,
),
)
else:
raise NotImplementedError
final_layer_config = module_configs["final_layer_config"]
self.add_mixin(
"final_layer",
instantiate_from_config(
final_layer_config,
hidden_size=self.hidden_size,
patch_size=self.patch_size,
out_channels=self.out_channels,
time_embed_dim=self.time_embed_dim,
latent_width=self.latent_width,
latent_height=self.latent_height,
elementwise_affine=self.elementwise_affine,
),
reinit=True,
)
if "lora_config" in module_configs:
lora_config = module_configs["lora_config"]
self.add_mixin("lora", instantiate_from_config(lora_config, layer_num=self.num_layers), reinit=True)
return
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
# print('x shape:', x.shape) # train phase: torch.Size([2, 8, 32, 60, 90])
b, t, d, h, w = x.shape
if x.dtype != self.dtype:
x = x.to(self.dtype)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False, dtype=self.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
# assert y.shape[0] == x.shape[0]
assert x.shape[0] % y.shape[0] == 0
y = y.repeat_interleave(x.shape[0] // y.shape[0], dim=0)
emb = emb + self.label_emb(y)
kwargs["seq_length"] = t * h * w // (self.patch_size**2)
kwargs["images"] = x
kwargs["emb"] = emb
kwargs["encoder_outputs"] = context
kwargs["text_length"] = context.shape[1]
kwargs["input_ids"] = kwargs["position_ids"] = kwargs["attention_mask"] = torch.ones((1, 1)).to(x.dtype)
output = super().forward(**kwargs)[0]
return output
#!/bin/bash
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
environs="WORLD_SIZE=1 RANK=0 LOCAL_RANK=0 LOCAL_WORLD_SIZE=1"
run_cmd="$environs python sample_sr.py --base configs/cogvideox_5b/cogvideox_5b_infer_sr.yaml"
echo ${run_cmd}
eval ${run_cmd}
echo "DONE on `hostname`"
SwissArmyTransformer==0.4.12
omegaconf==2.3.0
torch==2.4.0
torchvision==0.19.0
pytorch_lightning==2.3.3
kornia==0.7.3
beartype==0.18.5
numpy==2.0.1
fsspec==2024.5.0
safetensors==0.4.3
imageio-ffmpeg==0.5.1
imageio==2.34.2
# scipy==1.14.0
decord==0.6.0
wandb==0.17.5
deepspeed==0.14.4
import os
import math
import argparse
from typing import List, Union
from tqdm import tqdm
from omegaconf import ListConfig
import imageio
import torch
from einops import rearrange
import numpy as np
from einops import rearrange
import torchvision.transforms as TT
from sat.model.base_model import get_model
from sat.training.model_io import load_checkpoint
from sat import mpu
from diffusion_video import SATVideoDiffusionEngine
from arguments import get_args
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
from data_video import PairedCaptionDataset
from color_fix import adain_color_fix
def read_from_cli():
cnt = 0
try:
while True:
x = input("Please input English text (Ctrl-D quit): ")
yield x.strip(), cnt
cnt += 1
except EOFError:
pass
def read_from_file(p, rank=0, world_size=1):
with open(p, "r") as fin:
cnt = -1
for l in fin:
cnt += 1
if cnt % world_size != rank:
continue
yield l.strip(), cnt
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
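# Example (illustrative, hypothetical values): for a single prompt,
#   batch, batch_uc = get_batch(["txt"], {"prompt": "a cat", "negative_prompt": ""}, [1])
# yields batch["txt"] == ["a cat"] and batch_uc["txt"] == [""].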
def save_video_as_grid_and_mp4(video_batch: torch.Tensor, save_path: str, fps: int = 5, args=None, key=None):
os.makedirs(save_path, exist_ok=True)
for i, vid in enumerate(video_batch):
gif_frames = []
for frame in vid:
frame = rearrange(frame, "c h w -> h w c")
frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
gif_frames.append(frame)
now_save_path = os.path.join(save_path, f"{i:06d}.mp4")
with imageio.get_writer(now_save_path, fps=fps, quality=10) as writer:
for frame in gif_frames:
writer.append_data(frame)
def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)
h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)
delta_h = h - image_size[0]
delta_w = w - image_size[1]
if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
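# Note (editorial): the function above first resizes while preserving aspect ratio
# so the frame covers image_size, then crops an image_size[0] x image_size[1]
# window from the result (randomly or centered, per reshape_mode).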
def sampling_main(args, model_cls):
test_dataset = PairedCaptionDataset(data_dir='/mnt/bn/videodataset/VSR/dataset/VSRTest/cogvideox_test',
null_text_ratio=0, num_frames=25)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
num_workers=8,
batch_size=1,
shuffle=False
)
if isinstance(model_cls, type):
model = get_model(args, model_cls)
else:
model = model_cls
load_checkpoint(model, args)
model.eval()
if args.input_type == "cli":
data_iter = read_from_cli()
elif args.input_type == "txt":
rank, world_size = mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size()
print("rank and world_size", rank, world_size)
data_iter = read_from_file(args.input_file, rank=rank, world_size=world_size)
else:
raise NotImplementedError
image_size = [480, 720]
sample_func = model.sample_sr
T, H, W, C, F = args.sampling_num_frames, image_size[0], image_size[1], args.latent_channels, 8
num_samples = [1]
force_uc_zero_embeddings = ["txt"]
device = model.device
with torch.no_grad():
for step, batch in enumerate(test_dataloader):
cnt = step
gt = batch['mp4']
text = batch['txt']
lq = batch['lq']
fps = batch['fps']
# reload model on GPU
model.to(device)
print("rank:", rank, "start to process", text, cnt)
# TODO: broadcast image2video
value_dict = {
"prompt": text,
"negative_prompt": "",
"num_frames": torch.tensor(T).unsqueeze(0),
}
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, num_samples
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc))
for index in range(args.batch_size):
# reload model on GPU
model.to(device)
samples_z = sample_func(
c,
uc=uc,
batch_size=1,
shape=(T, C, H // F, W // F),
lq=lq,
)
samples_z = samples_z.permute(0, 2, 1, 3, 4).contiguous()
# print('max samples_z:', torch.max(samples_z)) # 3.0996
# print('min samples_z:', torch.min(samples_z)) # -3.0742
# Unload the model from GPU to save GPU memory
model.to("cpu")
torch.cuda.empty_cache()
first_stage_model = model.first_stage_model
first_stage_model = first_stage_model.to(device)
latent = 1.0 / model.scale_factor * samples_z
# Decode latent serial to save GPU memory
print('latent shape:', latent.shape)
recons = []
loop_num = (T - 1) // 2
for i in range(loop_num):
if i == 0:
start_frame, end_frame = 0, 3
else:
start_frame, end_frame = i * 2 + 1, i * 2 + 3
if i == loop_num - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(), clear_fake_cp_cache=clear_fake_cp_cache
)
recons.append(recon)
recon = torch.cat(recons, dim=2).to(torch.float32)
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
# Using color fix
samples = adain_color_fix(samples, gt) # samples, gt: (b, t, c, h, w)
save_path = os.path.join(
args.output_dir, str(cnt) + "_" + text[0].replace(" ", "_").replace("/", "")[:120]
)
save_path_gt = os.path.join(
args.output_dir, str(cnt) + "_gt_" + text[0].replace(" ", "_").replace("/", "")[:120]
)
save_path_lq = os.path.join(
args.output_dir, str(cnt) + "_lq_" + text[0].replace(" ", "_").replace("/", "")[:120]
)
if mpu.get_model_parallel_rank() == 0:
save_video_as_grid_and_mp4(samples, save_path, fps=float(fps))
# save_video_as_grid_and_mp4(torch.clamp((gt + 1.0) / 2.0, min=0.0, max=1.0).cpu(), save_path_gt, fps=float(fps))
# save_video_as_grid_and_mp4(torch.clamp((lq + 1.0) / 2.0, min=0.0, max=1.0).cpu(), save_path_lq, fps=float(fps))
if __name__ == "__main__":
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
py_parser = argparse.ArgumentParser(add_help=False)
known, args_list = py_parser.parse_known_args()
args = get_args(args_list)
args = argparse.Namespace(**vars(args), **vars(known))
del args.deepspeed_config
args.model_config.first_stage_config.params.cp_size = 1
args.model_config.network_config.params.transformer_args.model_parallel_size = 1
args.model_config.network_config.params.transformer_args.checkpoint_activations = False
args.model_config.loss_fn_config.params.sigma_sampler_config.params.uniform_sampling = False
sampling_main(args, model_cls=SATVideoDiffusionEngine)
from .models import AutoencodingEngine
from .util import get_configs_path, instantiate_from_config
__version__ = "0.1.0"
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
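# Usage sketch (assumption, not shown in this repo; values are hypothetical): the
# scheduler returns an lr value directly, so it is typically wrapped with torch's
# LambdaLR and a base lr of 1.0, as the docstring notes, e.g.
#   sched = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4,
#                                       lr_start=1e-7, max_decay_steps=100_000)
#   scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched)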
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = (
self.f_min[cycle]
+ (self.f_max[cycle] - self.f_min[cycle])
* (self.cycle_lengths[cycle] - n)
/ (self.cycle_lengths[cycle])
)
self.last_f = f
return f
from .autoencoder import AutoencodingEngine
import logging
import math
import re
import random
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version
from ..modules.autoencoding.regularizers import AbstractRegularizer
from ..modules.ema import LitEma
from ..util import (
default,
get_nested_attribute,
get_obj_from_str,
instantiate_from_config,
initialize_context_parallel,
get_context_parallel_group,
get_context_parallel_group_rank,
is_context_parallel_initialized,
)
from ..modules.cp_enc_dec import _conv_split, _conv_gather
logpy = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = "jpg",
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if version.parse(torch.__version__) >= version.parse("2.0.0"):
self.automatic_optimization = False
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
if isinstance(ckpt, str):
ckpt = {
"target": "sgm.modules.checkpoint.CheckpointEngine",
"params": {"ckpt_path": ckpt},
}
engine = instantiate_from_config(ckpt)
engine(self)
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@abstractmethod
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", dict()))
def configure_optimizers(self) -> Any:
raise NotImplementedError()
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
loss_config: Dict,
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.automatic_optimization = False # pytorch lightning
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
self.regularization: AbstractRegularizer = instantiate_from_config(regularizer_config)
self.optimizer_config = default(optimizer_config, {"target": "torch.optim.Adam"})
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
self.ae_optimizer_args = [{}] # makes type consistent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
else:
self.disc_optimizer_args = [{}] # makes type consistent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn("Checkpoint path is deprecated, use `ckpt_engine` instead")
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first
# format (e.g., bchw instead of bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = []
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
params += list(self.loss.get_trainable_autoencoder_parameters())
if hasattr(self.regularization, "get_trainable_parameters"):
params += list(self.regularization.get_trainable_parameters())
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
if hasattr(self.loss, "get_trainable_parameters"):
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x, **kwargs)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(self, x: torch.Tensor, **additional_decode_kwargs) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(self, batch: dict, batch_idx: int, optimizer_idx: int = 0) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": optimizer_idx,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "train",
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {"train/loss/rec": aeloss.detach()}
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
"loss",
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
)
return aeloss
elif optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
else:
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(batch, batch_idx, optimizer_idx=optimizer_idx)
self.manual_backward(loss)
opt.step()
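# Illustrative note (not part of the original file): with two optimizers the step
# alternates by batch index, e.g. batch_idx 0 -> optimizer_idx 0 (autoencoder),
# batch_idx 1 -> optimizer_idx 1 (discriminator), except that before
# `disc_start_iter` only the autoencoder optimizer is stepped. This assumes
# Lightning-style manual optimization, where `self.optimizers()` returns the
# optimizers built in `configure_optimizers`.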
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
if hasattr(self.loss, "forward_keys"):
extra_info = {
"z": z,
"optimizer_idx": 0,
"global_step": self.global_step,
"last_layer": self.get_last_layer(),
"split": "val" + postfix,
"regularization_log": regularization_log,
"autoencoder": self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
full_log_dict = log_dict_ae
if "optimizer_idx" in extra_info:
extra_info["optimizer_idx"] = 1
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
full_log_dict.update(log_dict_disc)
self.log(
f"val{postfix}/loss/rec",
log_dict_ae[f"val{postfix}/loss/rec"],
sync_dist=True,
)
self.log_dict(full_log_dict, sync_dist=True)
return full_log_dict
def get_param_groups(
self, parameter_names: List[List[str]], optimizer_args: List[dict]
) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(f"Did not find parameters for pattern {pattern_}")
params.extend(pattern_params)
groups.append({"params": params, **args})
return groups, num_params
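# Illustrative sketch (not part of the original file; values are hypothetical):
# `parameter_names` is a list of regex lists paired one-to-one with
# `optimizer_args` and matched against `self.named_parameters()` via `re.match`.
# For example,
#   trainable_ae_params = [["decoder\..*"], ["encoder\.conv_in\..*"]]
#   ae_optimizer_args = [{"lr": 1e-4}, {"lr": 1e-5}]
# yields two param groups, one per pattern list, each carrying its own optimizer
# arguments in `configure_optimizers`.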
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(self.trainable_ae_params, self.ae_optimizer_args)
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(self.trainable_disc_params, self.disc_optimizer_args)
logpy.info(f"Number of trainable discriminator parameters: {num_disc_params:,}")
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(disc_params, self.learning_rate, self.optimizer_config)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update({key: batch[key] for key in self.additional_decode_keys.intersection(batch)})
_, xrec, _ = self(x, **additional_decode_kwargs)
log["inputs"] = x
log["reconstructions"] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log["diff"] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log["diff_boost"] = 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
if hasattr(self.loss, "log_images"):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
log["reconstructions_ema"] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log["diff_ema"] = 2.0 * diff_ema - 1.0
log["diff_boost_ema"] = 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = "reconstructions-" + "-".join(
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
)
log[log_str] = xrec_add
return log
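# Illustrative note (not part of the original file): `diff` is half the absolute
# reconstruction error and therefore lies in [0, 1]; `2 * diff - 1` (and the
# boosted variant scaled by `diff_boost_factor`) maps it back to the [-1, 1]
# range used for the other logged images.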
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop("max_batch_size", None)
ddconfig = kwargs.pop("ddconfig")
ckpt_path = kwargs.pop("ckpt_path", None)
ckpt_engine = kwargs.pop("ckpt_engine", None)
super().__init__(
encoder_config={
"target": "sgm.modules.diffusionmodules.model.Encoder",
"params": ddconfig,
},
decoder_config={
"target": "sgm.modules.diffusionmodules.model.Decoder",
"params": ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
(1 + ddconfig["double_z"]) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(self, x: torch.Tensor, return_reg_log: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
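# Illustrative note (not part of the original file): setting `max_batch_size`
# bounds peak memory by running encode/decode in chunks of that size, e.g. with
# N = 10 samples and max_batch_size = 4 the loop covers slices [0:4], [4:8],
# [8:10] (ceil(10 / 4) = 3 chunks) and concatenates the results along dim 0,
# matching a single full-batch pass.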
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_input(self, x: Any) -> Any:
return x
def encode(self, x: Any, *args, **kwargs) -> Any:
return x
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class VideoAutoencodingEngine(AutoencodingEngine):
def __init__(
self,
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list] = (),
image_video_weights=[1, 1],
only_train_decoder=False,
context_parallel_size=0,
**kwargs,
):
super().__init__(**kwargs)
self.context_parallel_size = context_parallel_size
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
if self.context_parallel_size > 0:
if not is_context_parallel_initialized():
initialize_context_parallel(self.context_parallel_size)
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank() * self.context_parallel_size
torch.distributed.broadcast(batch, src=global_src_rank, group=get_context_parallel_group())
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch
return batch[self.input_key]
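# Illustrative note (not part of the original file): with context parallelism,
# every rank in a context-parallel group first receives the same batch via
# broadcast from the group's source rank; `_conv_split(batch, dim=2,
# kernel_size=1)` then slices the temporal dimension (dim 2 of a b,c,t,h,w
# tensor) so each rank only processes its own chunk of frames, and
# `_conv_gather` is the inverse used on the outputs.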
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print("Missing keys: ", missing_keys)
print("Unexpected keys: ", unexpected_keys)
print(f"Restored from {path}")
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
def __init__(
self,
cp_size=0,
*args,
**kwargs,
):
self.cp_size = cp_size
super().__init__(*args, **kwargs)
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
input_cp: bool = False,
output_cp: bool = False,
use_cp: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(x, src=global_src_rank, group=get_context_parallel_group())
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
else:
z = super().encode(x, return_reg_log, unregularized, use_cp=use_cp)
if self.cp_size > 0 and use_cp and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
return z, reg_log
return z
def decode(
self,
z: torch.Tensor,
input_cp: bool = False,
output_cp: bool = False,
use_cp: bool = True,
**kwargs,
):
if self.cp_size <= 1:
use_cp = False
if self.cp_size > 0 and use_cp and not input_cp:
if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z, src=global_src_rank, group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=1)
x = super().decode(z, use_cp=use_cp, **kwargs)
if self.cp_size > 0 and use_cp and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=1)
return x
def forward(
self,
x: torch.Tensor,
input_cp: bool = False,
latent_cp: bool = False,
output_cp: bool = False,
**additional_decode_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True, input_cp=input_cp, output_cp=latent_cp)
dec = self.decode(z, input_cp=latent_cp, output_cp=output_cp, **additional_decode_kwargs)
return z, dec, reg_log
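# Illustrative usage sketch (not part of the original file; names and values are
# hypothetical):
#   vae = VideoAutoencoderInferenceWrapper(cp_size=2, **engine_kwargs)
#   z = vae.encode(x, output_cp=True)    # keep latents split across the cp group
#   rec = vae.decode(z, input_cp=True)   # gather the reconstruction at the end
# With cp_size <= 1 the context-parallel branches are skipped and the wrapper
# behaves like a plain VideoAutoencodingEngine.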
from .encoders.modules import GeneralConditioner
UNCONDITIONAL_CONFIG = {
"target": "sgm.modules.GeneralConditioner",
"params": {"emb_models": []},
}
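# Illustrative note (not part of the original file): UNCONDITIONAL_CONFIG is the
# fallback used when no conditioner_config is supplied, e.g.
#   conditioner = instantiate_from_config(UNCONDITIONAL_CONFIG)
# builds a GeneralConditioner with an empty embedder list, i.e. a no-op
# conditioner.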