import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution:
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self):
# x = self.mean + self.std * torch.randn(self.mean.shape).to(
# device=self.parameters.device
# )
x = self.mean + self.std * torch.randn_like(
self.mean).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var +
self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
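# --- Illustrative usage sketch (not part of the original module). ---
# Assumes `parameters` packs mean and log-variance along dim=1, as
# DiagonalGaussianDistribution.__init__ expects via torch.chunk(..., 2, dim=1).
def _example_diagonal_gaussian_usage():
    # 8 latent channels -> 16 parameter channels (mean + logvar) on a 4x4 grid.
    parameters = torch.randn(2, 16, 4, 4)
    posterior = DiagonalGaussianDistribution(parameters)
    z = posterior.sample()   # reparameterized sample, shape (2, 8, 4, 4)
    kl = posterior.kl()      # KL against a standard normal, shape (2,)
    nll = posterior.nll(z)   # negative log-likelihood of z, shape (2,)
    return z, kl, nll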
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = (x if isinstance(x, torch.Tensor) else
torch.tensor(x).to(tensor) for x in (logvar1, logvar2))
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * torch.exp(-logvar2))
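# --- Illustrative usage sketch (not part of the original module). ---
# normal_kl broadcasts its arguments, so a batch of Gaussians can be compared
# against scalar parameters; log-variances given as Python floats are promoted
# to tensors internally.
def _example_normal_kl_broadcasting():
    mean1 = torch.zeros(4, 3)
    logvar1 = torch.zeros(4, 3)
    # Scalar reference Gaussian with mean 1.0 and logvar 0.5; result keeps the (4, 3) shape.
    kl = normal_kl(mean1, logvar1, 1.0, 0.5)
    return kl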
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,
(1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key])
shadow_params[sname].sub_(
one_minus_decay *
(shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
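# --- Illustrative usage sketch (not part of the original module). ---
# Typical EMA workflow around a validation step, assuming `model` is any
# nn.Module and `val_fn` is a user-supplied evaluation callable (both are
# placeholders, not defined in this file).
def _example_ema_validation(model: nn.Module, val_fn):
    ema = LitEma(model)
    # ... one or more training steps happen here ...
    ema(model)                         # update the shadow (EMA) weights
    ema.store(model.parameters())      # remember the live weights
    ema.copy_to(model)                 # validate with EMA weights
    result = val_fn(model)
    ema.restore(model.parameters())    # put the live weights back
    return result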
import math
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
import kornia
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from torch.utils.checkpoint import checkpoint
from transformers import T5EncoderModel, T5Tokenizer
from ...util import (append_dims, autocast, count_params, default,
disabled_train, expand_dims_like, instantiate_from_config)
class AbstractEmbModel(nn.Module):
def __init__(self):
super().__init__()
self._is_trainable = None
self._ucg_rate = None
self._input_key = None
@property
def is_trainable(self) -> bool:
return self._is_trainable
@property
def ucg_rate(self) -> Union[float, torch.Tensor]:
return self._ucg_rate
@property
def input_key(self) -> str:
return self._input_key
@is_trainable.setter
def is_trainable(self, value: bool):
self._is_trainable = value
@ucg_rate.setter
def ucg_rate(self, value: Union[float, torch.Tensor]):
self._ucg_rate = value
@input_key.setter
def input_key(self, value: str):
self._input_key = value
@is_trainable.deleter
def is_trainable(self):
del self._is_trainable
@ucg_rate.deleter
def ucg_rate(self):
del self._ucg_rate
@input_key.deleter
def input_key(self):
del self._input_key
class GeneralConditioner(nn.Module):
OUTPUT_DIM2KEYS = {2: 'vector', 3: 'crossattn', 4: 'concat', 5: 'concat'}
KEY2CATDIM = {'vector': 1, 'crossattn': 2, 'concat': 1}
def __init__(self,
emb_models: Union[List, ListConfig],
cor_embs=[],
cor_p=[]):
super().__init__()
embedders = []
for n, embconfig in enumerate(emb_models):
embedder = instantiate_from_config(embconfig)
assert isinstance(
embedder, AbstractEmbModel
), f'embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel'
embedder.is_trainable = embconfig.get('is_trainable', False)
embedder.ucg_rate = embconfig.get('ucg_rate', 0.0)
if not embedder.is_trainable:
embedder.train = disabled_train
for param in embedder.parameters():
param.requires_grad = False
embedder.eval()
print(
f'Initialized embedder #{n}: {embedder.__class__.__name__} '
f'with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}'
)
if 'input_key' in embconfig:
embedder.input_key = embconfig['input_key']
elif 'input_keys' in embconfig:
embedder.input_keys = embconfig['input_keys']
else:
raise KeyError(
f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
)
embedder.legacy_ucg_val = embconfig.get('legacy_ucg_value', None)
if embedder.legacy_ucg_val is not None:
embedder.ucg_prng = np.random.RandomState()
embedders.append(embedder)
self.embedders = nn.ModuleList(embedders)
if len(cor_embs) > 0:
assert len(cor_p) == 2**len(cor_embs)
self.cor_embs = cor_embs
self.cor_p = cor_p
def possibly_get_ucg_val(self, embedder: AbstractEmbModel,
batch: Dict) -> Dict:
assert embedder.legacy_ucg_val is not None
p = embedder.ucg_rate
val = embedder.legacy_ucg_val
for i in range(len(batch[embedder.input_key])):
if embedder.ucg_prng.choice(2, p=[1 - p, p]):
batch[embedder.input_key][i] = val
return batch
def surely_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict,
cond_or_not) -> Dict:
assert embedder.legacy_ucg_val is not None
val = embedder.legacy_ucg_val
for i in range(len(batch[embedder.input_key])):
if cond_or_not[i]:
batch[embedder.input_key][i] = val
return batch
def get_single_embedding(
self,
embedder,
batch,
output,
cond_or_not: Optional[np.ndarray] = None,
force_zero_embeddings: Optional[List] = None,
):
embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
with embedding_context():
if hasattr(embedder, 'input_key') and (embedder.input_key
is not None):
if embedder.legacy_ucg_val is not None:
if cond_or_not is None:
batch = self.possibly_get_ucg_val(embedder, batch)
else:
batch = self.surely_get_ucg_val(
embedder, batch, cond_or_not)
emb_out = embedder(batch[embedder.input_key])
elif hasattr(embedder, 'input_keys'):
emb_out = embedder(*[batch[k] for k in embedder.input_keys])
assert isinstance(
emb_out, (torch.Tensor, list, tuple)
), f'encoder outputs must be tensors or a sequence, but got {type(emb_out)}'
if not isinstance(emb_out, (list, tuple)):
emb_out = [emb_out]
for emb in emb_out:
out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
if cond_or_not is None:
emb = (expand_dims_like(
torch.bernoulli(
(1.0 - embedder.ucg_rate) *
torch.ones(emb.shape[0], device=emb.device)),
emb,
) * emb)
else:
emb = (expand_dims_like(
torch.tensor(1 - cond_or_not,
dtype=emb.dtype,
device=emb.device),
emb,
) * emb)
if hasattr(embedder, 'input_key'
) and embedder.input_key in force_zero_embeddings:
emb = torch.zeros_like(emb)
if out_key in output:
output[out_key] = torch.cat((output[out_key], emb),
self.KEY2CATDIM[out_key])
else:
output[out_key] = emb
return output
def forward(self,
batch: Dict,
force_zero_embeddings: Optional[List] = None) -> Dict:
output = dict()
if force_zero_embeddings is None:
force_zero_embeddings = []
if len(self.cor_embs) > 0:
batch_size = len(batch[list(batch.keys())[0]])
rand_idx = np.random.choice(len(self.cor_p),
size=(batch_size, ),
p=self.cor_p)
for emb_idx in self.cor_embs:
cond_or_not = rand_idx % 2
rand_idx //= 2
output = self.get_single_embedding(
self.embedders[emb_idx],
batch,
output=output,
cond_or_not=cond_or_not,
force_zero_embeddings=force_zero_embeddings,
)
for i, embedder in enumerate(self.embedders):
if i in self.cor_embs:
continue
output = self.get_single_embedding(
embedder,
batch,
output=output,
force_zero_embeddings=force_zero_embeddings)
return output
def get_unconditional_conditioning(self,
batch_c,
batch_uc=None,
force_uc_zero_embeddings=None):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
ucg_rates = list()
for embedder in self.embedders:
ucg_rates.append(embedder.ucg_rate)
embedder.ucg_rate = 0.0
cor_embs = self.cor_embs
cor_p = self.cor_p
self.cor_embs = []
self.cor_p = []
c = self(batch_c)
uc = self(batch_c if batch_uc is None else batch_uc,
force_uc_zero_embeddings)
for embedder, rate in zip(self.embedders, ucg_rates):
embedder.ucg_rate = rate
self.cor_embs = cor_embs
self.cor_p = cor_p
return c, uc
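# --- Illustrative config sketch (not part of the original module). ---
# GeneralConditioner expects one config per embedder; each entry needs a
# 'target' for instantiate_from_config plus either 'input_key' or 'input_keys'.
# The module path below is hypothetical -- adjust it to wherever
# FrozenT5Embedder lives in your package layout.
_EXAMPLE_EMB_MODELS = [
    {
        'target': 'sgm.modules.encoders.modules.FrozenT5Embedder',  # hypothetical path
        'is_trainable': False,
        'ucg_rate': 0.1,        # drop the caption 10% of the time for CFG training
        'input_key': 'txt',
        'params': {'max_length': 77},
    },
]
# conditioner = GeneralConditioner(_EXAMPLE_EMB_MODELS)
# out = conditioner({'txt': ['a photo of a cat']})  # -> {'crossattn': ...}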
class FrozenT5Embedder(AbstractEmbModel):
"""Uses the T5 transformer encoder for text"""
def __init__(
self,
model_dir='google/t5-v1_1-xxl',
device='cuda',
max_length=77,
freeze=True,
cache_dir=None,
):
super().__init__()
if model_dir != 'google/t5-v1_1-xxl':
self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
self.transformer = T5EncoderModel.from_pretrained(model_dir)
else:
self.tokenizer = T5Tokenizer.from_pretrained(model_dir,
cache_dir=cache_dir)
self.transformer = T5EncoderModel.from_pretrained(
model_dir, cache_dir=cache_dir)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
# @autocast
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt',
)
tokens = batch_encoding['input_ids'].to(self.device)
with torch.autocast('cuda', enabled=False):
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
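# --- Illustrative usage sketch (not part of the original module). ---
# Standalone use of the T5 text encoder; this downloads google/t5-v1_1-xxl on
# first use unless a local `model_dir` is given, and the hidden size depends
# on the chosen checkpoint.
def _example_t5_encode(prompts=('a photo of a cat', ), device='cuda'):
    embedder = FrozenT5Embedder(device=device).to(device)
    z = embedder.encode(list(prompts))  # (batch, 77, hidden_dim) token embeddings
    return z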
import torch
from ..modules.attention import *
from ..modules.diffusionmodules.util import (AlphaBlender, linear,
timestep_embedding)
class TimeMixSequential(nn.Sequential):
def forward(self, x, context=None, timesteps=None):
for layer in self:
x = layer(x, context, timesteps)
return x
class VideoTransformerBlock(nn.Module):
ATTENTION_MODES = {
'softmax': CrossAttention,
'softmax-xformers': MemoryEfficientCrossAttention,
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
timesteps=None,
ff_in=False,
inner_dim=None,
attn_mode='softmax',
disable_self_attn=False,
disable_temporal_crossattention=False,
switch_temporal_ca_to_sa=False,
):
super().__init__()
attn_cls = self.ATTENTION_MODES[attn_mode]
self.ff_in = ff_in or inner_dim is not None
if inner_dim is None:
inner_dim = dim
assert int(n_heads * d_head) == inner_dim
self.is_res = inner_dim == dim
if self.ff_in:
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(dim,
dim_out=inner_dim,
dropout=dropout,
glu=gated_ff)
self.timesteps = timesteps
self.disable_self_attn = disable_self_attn
if self.disable_self_attn:
self.attn1 = attn_cls(
query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
context_dim=context_dim,
dropout=dropout,
) # is a cross-attention
else:
self.attn1 = attn_cls(query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is a self-attention
self.ff = FeedForward(inner_dim,
dim_out=dim,
dropout=dropout,
glu=gated_ff)
if disable_temporal_crossattention:
if switch_temporal_ca_to_sa:
raise ValueError
else:
self.attn2 = None
else:
self.norm2 = nn.LayerNorm(inner_dim)
if switch_temporal_ca_to_sa:
self.attn2 = attn_cls(query_dim=inner_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is a self-attention
else:
self.attn2 = attn_cls(
query_dim=inner_dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(inner_dim)
self.norm3 = nn.LayerNorm(inner_dim)
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
self.checkpoint = checkpoint
if self.checkpoint:
print(f'{self.__class__.__name__} is using checkpointing')
def forward(self,
x: torch.Tensor,
context: torch.Tensor = None,
timesteps: int = None) -> torch.Tensor:
if self.checkpoint:
return checkpoint(self._forward, x, context, timesteps)
else:
return self._forward(x, context, timesteps=timesteps)
def _forward(self, x, context=None, timesteps=None):
assert self.timesteps or timesteps
assert not (self.timesteps
and timesteps) or self.timesteps == timesteps
timesteps = self.timesteps or timesteps
B, S, C = x.shape
x = rearrange(x, '(b t) s c -> (b s) t c', t=timesteps)
if self.ff_in:
x_skip = x
x = self.ff_in(self.norm_in(x))
if self.is_res:
x += x_skip
if self.disable_self_attn:
x = self.attn1(self.norm1(x), context=context) + x
else:
x = self.attn1(self.norm1(x)) + x
if self.attn2 is not None:
if self.switch_temporal_ca_to_sa:
x = self.attn2(self.norm2(x)) + x
else:
x = self.attn2(self.norm2(x), context=context) + x
x_skip = x
x = self.ff(self.norm3(x))
if self.is_res:
x += x_skip
x = rearrange(x,
'(b s) t c -> (b t) s c',
s=S,
b=B // timesteps,
c=C,
t=timesteps)
return x
def get_last_layer(self):
return self.ff.net[-1].weight
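# --- Illustrative shape sketch (not part of the original module). ---
# VideoTransformerBlock expects tokens flattened as (batch * time, seq, channels)
# and attends over the time axis internally, so `timesteps` must divide the
# leading dimension.  Assumes PyTorch >= 2.0 so the default SDP attention path
# used by CrossAttention is available.
def _example_video_block_shapes():
    b, t, s, c = 2, 8, 16, 64  # batch, frames, spatial tokens, channels
    block = VideoTransformerBlock(dim=c,
                                  n_heads=4,
                                  d_head=16,
                                  timesteps=t,
                                  checkpoint=False)
    x = torch.randn(b * t, s, c)
    out = block(x)             # same shape: (b * t, s, c)
    return out.shape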
str_to_dtype = {
'fp32': torch.float32,
'fp16': torch.float16,
'bf16': torch.bfloat16
}
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = 'fixed',
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
attn_mode='softmax',
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype='fp32',
):
super().__init__(
in_channels,
n_heads,
d_head,
depth=depth,
dropout=dropout,
attn_type=attn_mode,
use_checkpoint=checkpoint,
context_dim=context_dim,
use_linear=use_linear,
disable_self_attn=disable_self_attn,
)
self.time_depth = time_depth
self.depth = depth
self.max_time_embed_period = max_time_embed_period
time_mix_d_head = d_head
n_time_mix_heads = n_heads
time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
inner_dim = n_heads * d_head
if use_spatial_context:
time_context_dim = context_dim
self.time_stack = nn.ModuleList([
VideoTransformerBlock(
inner_dim,
n_time_mix_heads,
time_mix_d_head,
dropout=dropout,
context_dim=time_context_dim,
timesteps=timesteps,
checkpoint=checkpoint,
ff_in=ff_in,
inner_dim=time_mix_inner_dim,
attn_mode=attn_mode,
disable_self_attn=disable_self_attn,
disable_temporal_crossattention=disable_temporal_crossattention,
) for _ in range(self.depth)
])
assert len(self.time_stack) == len(self.transformer_blocks)
self.use_spatial_context = use_spatial_context
self.in_channels = in_channels
time_embed_dim = self.in_channels * 4
self.time_pos_embed = nn.Sequential(
linear(self.in_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, self.in_channels),
)
self.time_mixer = AlphaBlender(alpha=merge_factor,
merge_strategy=merge_strategy)
self.dtype = str_to_dtype[dtype]
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
spatial_context = None
if exists(context):
spatial_context = context
if self.use_spatial_context:
assert context.ndim == 3, f'n dims of spatial context should be 3 but are {context.ndim}'
time_context = context
time_context_first_timestep = time_context[::timesteps]
time_context = repeat(time_context_first_timestep,
'b ... -> (b n) ...',
n=h * w)
elif time_context is not None and not self.use_spatial_context:
time_context = repeat(time_context, 'b ... -> (b n) ...', n=h * w)
if time_context.ndim == 2:
time_context = rearrange(time_context, 'b c -> b 1 c')
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
if self.use_linear:
x = self.proj_in(x)
num_frames = torch.arange(timesteps, device=x.device)
num_frames = repeat(num_frames, 't -> b t', b=x.shape[0] // timesteps)
num_frames = rearrange(num_frames, 'b t -> (b t)')
t_emb = timestep_embedding(
num_frames,
self.in_channels,
repeat_only=False,
max_period=self.max_time_embed_period,
dtype=self.dtype,
)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)):
x = block(
x,
context=spatial_context,
)
x_mix = x
x_mix = x_mix + emb
x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
x = self.time_mixer(
x_spatial=x,
x_temporal=x_mix,
image_only_indicator=image_only_indicator,
)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
import functools
import importlib
import os
from functools import partial
from inspect import isfunction
import fsspec
import numpy as np
import torch
import torch.distributed
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
_CONTEXT_PARALLEL_GROUP = None
_CONTEXT_PARALLEL_SIZE = None
def is_context_parallel_initialized():
if _CONTEXT_PARALLEL_GROUP is None:
return False
else:
return True
def set_context_parallel_group(size, group):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
_CONTEXT_PARALLEL_GROUP = group
_CONTEXT_PARALLEL_SIZE = size
def initialize_context_parallel(context_parallel_size):
global _CONTEXT_PARALLEL_GROUP
global _CONTEXT_PARALLEL_SIZE
assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
_CONTEXT_PARALLEL_SIZE = context_parallel_size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
for i in range(0, world_size, context_parallel_size):
ranks = range(i, i + context_parallel_size)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_CONTEXT_PARALLEL_GROUP = group
break
def get_context_parallel_group():
assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
return _CONTEXT_PARALLEL_GROUP
def get_context_parallel_world_size():
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
return _CONTEXT_PARALLEL_SIZE
def get_context_parallel_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
rank = torch.distributed.get_rank()
cp_rank = rank % _CONTEXT_PARALLEL_SIZE
return cp_rank
def get_context_parallel_group_rank():
assert _CONTEXT_PARALLEL_SIZE is not None, 'context parallel size is not initialized'
rank = torch.distributed.get_rank()
cp_group_rank = rank // _CONTEXT_PARALLEL_SIZE
return cp_group_rank
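# --- Illustrative sketch (not part of the original module). ---
# How a global rank maps onto context-parallel coordinates: with
# context_parallel_size = 4 and 8 GPUs, ranks 0-3 form group 0 and ranks 4-7
# form group 1, and the rank within each group is `rank % 4`.
def _example_context_parallel_layout(world_size=8, cp_size=4):
    layout = {}
    for rank in range(world_size):
        layout[rank] = {
            'cp_rank': rank % cp_size,    # mirrors get_context_parallel_rank()
            'cp_group': rank // cp_size,  # mirrors get_context_parallel_group_rank()
        }
    return layout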
class SafeConv3d(torch.nn.Conv3d):
def forward(self, input):
memory_count = torch.prod(torch.tensor(
input.shape)).item() * 2 / 1024**3
if memory_count > 2:
# print(f"WARNING: Conv3d with {memory_count:.2f}GB")
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
input_chunks = torch.chunk(input, part_num, dim=2) # NCTHW
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [
torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1:],
input_chunks[i]),
dim=2) for i in range(1, len(input_chunks))
]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super().forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super().forward(input)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def get_string_from_tuple(s):
try:
# Check if the string starts and ends with parentheses
if s[0] == '(' and s[-1] == ')':
# Convert the string to a tuple
t = eval(s)
# Check if the type of t is tuple
if type(t) == tuple:
return t[0]
else:
pass
except:
pass
return s
def is_power_of_two(n):
"""
    Return True if n is a power of 2, otherwise return False.
    A positive power of two has exactly one bit set in its binary representation,
    so n & (n - 1) clears that bit and yields 0; any other positive n keeps at
    least one bit set, and n <= 0 can never be a power of two.
    (Explanation adapted from chat.openai.com/chat.)
"""
if n <= 0:
return False
return (n & (n - 1)) == 0
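# --- Illustrative sketch (not part of the original module). ---
# A few spot checks of the bit trick described in the docstring above.
def _example_power_of_two_checks():
    return [(n, is_power_of_two(n)) for n in (0, 1, 2, 6, 8, 1024)]
    # -> [(0, False), (1, True), (2, True), (6, False), (8, True), (1024, True)]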
def autocast(f, enabled=True):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=enabled,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled(),
):
return f(*args, **kwargs)
return do_autocast
def load_partial_from_config(config):
return partial(get_obj_from_str(config['target']),
**config.get('params', dict()))
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
if isinstance(xc[bi], list):
text_seq = xc[bi][0]
else:
text_seq = xc[bi]
lines = '\n'.join(text_seq[start:start + nc]
for start in range(0, len(text_seq), nc))
try:
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
            print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def partialclass(cls, *args, **kwargs):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
return NewCls
def make_path_absolute(path):
fs, p = fsspec.core.url_to_fs(path)
if fs.protocol == 'file':
return os.path.abspath(p)
return path
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def isheatmap(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 2
def isneighbors(x):
if not isinstance(x, torch.Tensor):
return False
return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
def exists(x):
return x is not None
def expand_dims_like(x, y):
while x.dim() != y.dim():
x = x.unsqueeze(-1)
return x
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config, **extra_kwargs):
if not 'target' in config:
if config == '__is_first_stage__':
return None
elif config == '__is_unconditional__':
return None
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(**config.get('params', dict()),
**extra_kwargs)
def get_obj_from_str(string, reload=False, invalidate_cache=True):
module, cls = string.rsplit('.', 1)
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f'input has {x.ndim} dims but target_dims is {target_dims}, which is less'
)
return x[(..., ) + (None, ) * dims_to_append]
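# --- Illustrative sketch (not part of the original module). ---
# append_dims is the usual trick for broadcasting per-sample scalars (e.g.
# diffusion sigmas of shape (B,)) against image-shaped tensors (B, C, H, W).
def _example_append_dims():
    sigmas = torch.rand(4)                     # one value per batch element
    x = torch.randn(4, 3, 32, 32)
    scaled = x * append_dims(sigmas, x.ndim)   # sigmas viewed as (4, 1, 1, 1)
    return scaled.shape                        # torch.Size([4, 3, 32, 32])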
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
print(f'Loading model from {ckpt}')
if ckpt.endswith('ckpt'):
pl_sd = torch.load(ckpt, map_location='cpu')
if 'global_step' in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd['state_dict']
elif ckpt.endswith('safetensors'):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print('missing keys:')
print(m)
if len(u) > 0 and verbose:
print('unexpected keys:')
print(u)
if freeze:
for param in model.parameters():
param.requires_grad = False
model.eval()
return model
def get_configs_path() -> str:
"""
Get the `configs` directory.
For a working copy, this is the one in the root of the repository,
but for an installed copy, it's in the `sgm` package (see pyproject.toml).
"""
this_dir = os.path.dirname(__file__)
candidates = (
os.path.join(this_dir, 'configs'),
os.path.join(this_dir, '..', 'configs'),
)
for candidate in candidates:
candidate = os.path.abspath(candidate)
if os.path.isdir(candidate):
return candidate
raise FileNotFoundError(f'Could not find SGM configs in {candidates}')
def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
"""
Will return the result of a recursive get attribute call.
E.g.:
a.b.c
= getattr(getattr(a, "b"), "c")
= get_nested_attribute(a, "b.c")
If any part of the attribute call is an integer x with current obj a, will
try to call a[x] instead of a.x first.
"""
attributes = attribute_path.split('.')
if depth is not None and depth > 0:
attributes = attributes[:depth]
assert len(attributes) > 0, 'At least one attribute should be selected'
current_attribute = obj
current_key = None
for level, attribute in enumerate(attributes):
current_key = '.'.join(attributes[:level + 1])
try:
id_ = int(attribute)
current_attribute = current_attribute[id_]
except ValueError:
current_attribute = getattr(current_attribute, attribute)
return (current_attribute,
current_key) if return_key else current_attribute
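# --- Illustrative sketch (not part of the original module). ---
# Integer path components index into sequences, so '1.bias' below resolves to
# model[1].bias; with return_key=True the resolved path is returned as well.
def _example_nested_attribute():
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    bias = get_nested_attribute(model, '1.bias')
    value, key = get_nested_attribute(model, '1.bias', return_key=True)
    return bias.shape, key  # (torch.Size([2]), '1.bias')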
from math import sqrt
class SeededNoise:
def __init__(self, seeds, weights):
self.seeds = seeds
self.weights = weights
weight_square_sum = 0
for weight in weights:
weight_square_sum += weight**2
self.weight_square_sum_sqrt = sqrt(weight_square_sum)
self.cnt = 0
def __call__(self, x):
self.cnt += 1
randn_combined = torch.zeros_like(x)
for seed, weight in zip(self.seeds, self.weights):
randn = np.random.RandomState(seed + self.cnt).randn(*x.shape)
            randn = torch.from_numpy(randn).to(dtype=x.dtype, device=x.device)
randn_combined += randn * weight
randn_combined /= self.weight_square_sum_sqrt
return randn_combined
import io
import json
import os
import re
import sys
import tarfile
from functools import partial
import webdataset as wds
from webdataset import DataPipeline, ResampledShards, tarfile_to_samples
from webdataset.filters import pipelinefilter
from webdataset.gopen import gopen, gopen_schemes
from webdataset.handlers import reraise_exception
from webdataset.tariterators import group_by_keys, url_opener
def pytorch_worker_info(group=None): # sourcery skip: use-contextlib-suppress
"""Return node and worker info for PyTorch and some distributed environments."""
rank = 0
world_size = 1
worker = 0
num_workers = 1
try:
import torch.distributed
if torch.distributed.is_available(
) and torch.distributed.is_initialized():
group = group or torch.distributed.group.WORLD
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
except ModuleNotFoundError:
pass
try:
import torch.utils.data
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
worker = worker_info.id
num_workers = worker_info.num_workers
except ModuleNotFoundError:
pass
return rank, world_size, worker, num_workers
def pytorch_worker_seed(group=None):
"""Compute a distinct, deterministic RNG seed for each worker and node."""
rank, world_size, worker, num_workers = pytorch_worker_info(group=group)
return rank * 1000 + worker
def worker_seed_sat(group=None, seed=0):
return pytorch_worker_seed(group=group) + seed * 23
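# --- Illustrative sketch (not part of the original module). ---
# worker_seed_sat spreads seeds so that different ranks, dataloader workers and
# user seeds rarely collide: rank * 1000 + worker + seed * 23.
def _example_worker_seed(rank=2, worker=3, seed=5):
    # Mirrors pytorch_worker_seed/worker_seed_sat without needing a process group.
    return rank * 1000 + worker + seed * 23  # -> 2118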
class ConfiguredResampledShards(ResampledShards):
def __init__(self, urls, seed, nshards=sys.maxsize, deterministic=True):
from sat.helpers import print_rank0
try:
from megatron.core.parallel_state import get_data_parallel_group
group = get_data_parallel_group()
print_rank0('Using megatron data parallel group.')
except:
from sat.mpu import get_data_parallel_group
try:
group = get_data_parallel_group()
print_rank0('Using sat data parallel group.')
except AssertionError:
group = None
print_rank0('No data parallel group is specified!')
worker_seed_sat_this = partial(worker_seed_sat, group=group, seed=seed)
super().__init__(urls, nshards, worker_seed_sat_this, deterministic)
class SimpleDistributedWebDataset(DataPipeline):
def __init__(self, path, process_fn, seed, *, shuffle_buffer=1000):
        # set shuffle_buffer = 1 to disable shuffling; with model parallelism,
        # shuffling would make samples differ across model-parallel ranks
try:
from sat.mpu import get_model_parallel_world_size
if get_model_parallel_world_size() > 1:
shuffle_buffer = 1
except Exception:
pass
super().__init__(
ConfiguredResampledShards(
                path, seed),  # many shards are recommended, otherwise they may not spread evenly across workers
tarfile_to_samples(),
wds.shuffle(shuffle_buffer),
process_fn,
)
def tar_file_iterator_with_meta(fileobj,
meta_names,
skip_meta=r'__[^/]*__($|/)',
suffix=None,
handler=reraise_exception,
meta_stream=None):
"""Iterate over tar file, yielding filename, content pairs for the given tar stream.
:param fileobj: byte stream suitable for tarfile
:param meta_names: key of different items in meta file
:param skip_meta: regexp for keys that are skipped entirely (Default value = r"__[^/]*__($|/)")
"""
stream = tarfile.open(fileobj=fileobj, mode='r|*')
data_dir, filename = fileobj.name.rsplit('/', 1)
meta_data = {
} # {id: {meta_name: meta_value, meta_name2: meta_value2, ...}}
if meta_stream is None:
meta_file_name = filename.split('.')[0] + '.meta.jsonl'
meta_path = os.path.join(data_dir, meta_file_name)
if os.path.exists(meta_path):
meta_stream = open(meta_path)
else:
meta_file_name = meta_stream.name
if meta_stream is not None:
for lineno, line in enumerate(meta_stream):
meta_list = []
try:
meta_list.append(json.loads(line))
except Exception as exn:
from sat.helpers import print_rank0
print_rank0(
f'Error in loading jsonl {meta_file_name}, lineno {lineno}: {line}',
level='DEBUG')
continue
for item in meta_list:
if not item['key'] in meta_data:
meta_data[item['key']] = {}
for meta_name in meta_names:
if meta_name in item:
meta_data[item['key']][meta_name] = item[meta_name]
meta_stream.close()
try:
for tarinfo in stream:
fname = tarinfo.name
try:
if not tarinfo.isreg():
continue
if fname is None:
continue
if '/' not in fname and fname.startswith(
'__') and fname.endswith('__'):
# skipping metadata for now
continue
if skip_meta is not None and re.match(skip_meta, fname):
continue
if fname.endswith('.txt') and suffix is not None:
data = (stream.extractfile(tarinfo).read().decode() +
suffix).encode()
else:
data = stream.extractfile(tarinfo).read()
result = dict(fname=fname, data=data)
yield result
if fname.endswith('.id'):
fid = fname.split('.')[0]
if '-$#%@&' in fid:
sfid = fid.split('-$#%@&')[0]
else:
sfid = fid
meta_data_fid = meta_data.get(sfid, {})
for meta_name in meta_names:
meta_fname = fid + '.' + meta_name
meta = meta_data_fid.get(meta_name, None)
yield dict(fname=meta_fname, data=meta)
stream.members = []
except Exception as exn:
if hasattr(exn, 'args') and len(exn.args) > 0:
exn.args = (exn.args[0] + ' @ ' +
str(fileobj), ) + exn.args[1:]
if handler(exn):
continue
else:
break
except Exception as exn:
print(exn)
del stream
def tar_file_expander_with_meta(data, meta_names, handler=reraise_exception):
"""Expand a stream of open tar files into a stream of tar file contents.
This returns an iterator over (filename, file_contents).
"""
for source in data:
url = source['url']
try:
assert isinstance(source, dict)
assert 'stream' in source
for sample in tar_file_iterator_with_meta(
source['stream'],
meta_names,
meta_stream=source['meta_stream']):
assert isinstance(
sample, dict) and 'data' in sample and 'fname' in sample
sample['__url__'] = url
yield sample
except Exception as exn:
exn.args = exn.args + (source.get('stream'), source.get('url'))
if handler(exn):
continue
else:
break
def url_opener(
data,
handler,
**kw,
):
"""Open URLs and yield a stream of url+stream pairs.
Args:
data: iterator over dict(url=...)
handler: exception handler.
kw: keyword arguments for gopen.gopen.
Yields:
a stream of url+stream pairs.
"""
for sample in data:
assert isinstance(sample, dict), sample
assert 'url' in sample
url = sample['url']
try:
stream = gopen(url, **kw)
if hasattr(stream, 'meta_stream'):
meta_stream = stream.meta_stream
del stream.meta_stream
else:
meta_stream = None
sample.update(stream=stream, meta_stream=meta_stream)
yield sample
except Exception as exn:
exn.args = exn.args + (url, )
if handler(exn):
continue
else:
break
def tarfile_samples_with_meta(src, meta_names, handler=reraise_exception):
streams = url_opener(src, handler=handler)
files = tar_file_expander_with_meta(streams, meta_names, handler)
samples = group_by_keys(files, handler=handler)
return samples
class MetaDistributedWebDataset(DataPipeline):
"""WebDataset with meta information files
Extra Format:
in webdataset (tar), for each sample there is a '.id';
for each tar file, there is a '.meta.jsonl' file with the same name;
The '.meta.jsonl' file contains lines of json objects, each with a 'key' field to match '.id'.
"""
def __init__(self,
path,
process_fn,
seed,
*,
meta_names=[],
nshards=sys.maxsize,
shuffle_buffer=1000,
include_dirs=None):
# os.environ['WDS_SHOW_SEED'] = '1'
import torch
if torch.distributed.get_rank() == 0:
if include_dirs is not None: # /webdatasets/A,/webdatasets/C
other_paths = []
include_dirs = include_dirs.split(',')
for include_dir in include_dirs:
if '*' in include_dir:
include_dir, n = include_dir.split('*')
n = int(n)
else:
n = 1
for cur_dir, dirs, files in os.walk(include_dir):
for f in files:
if f.endswith('tar') and os.path.getsize(
os.path.join(cur_dir, f)) > 0:
# other_paths.append(os.path.join(cur_dir,f))
other_paths.extend([os.path.join(cur_dir, f)] *
n)
# print(f'Adding dataset paths {",".join(other_paths)}')
from braceexpand import braceexpand
if len(path) > 0: # not ""
path = list(braceexpand(path)) + other_paths
else:
path = other_paths
path = [path]
else:
path = [
None,
]
torch.distributed.broadcast_object_list(path, src=0)
path = path[0]
tarfile_samples = partial(tarfile_samples_with_meta,
meta_names=meta_names)
tarfile_to_samples = pipelinefilter(tarfile_samples)
# if model parallel, shuffle_buffer should be 1 to disable shuffling
try:
from sat.mpu import get_model_parallel_world_size
if get_model_parallel_world_size() > 1:
shuffle_buffer = 1
except Exception:
pass
super().__init__(
ConfiguredResampledShards(path, seed, nshards=nshards),
tarfile_to_samples(),
wds.shuffle(shuffle_buffer),
process_fn,
)
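# --- Illustrative data-format sketch (not part of the original module). ---
# For a shard `shard-000.tar` containing samples `abc.id`, `abc.mp4`, ..., the
# expected companion file is `shard-000.meta.jsonl`, one JSON object per line,
# whose 'key' matches the sample id; the requested `meta_names` fields are then
# attached to each sample as extra `<id>.<meta_name>` entries.
_EXAMPLE_META_JSONL_LINE = '{"key": "abc", "caption": "a cat", "fps": 24}'
# dataset = MetaDistributedWebDataset('path/to/shard-{000..009}.tar',
#                                     process_fn, seed=0,
#                                     meta_names=['caption', 'fps'])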
# rclone support
from webdataset.gopen import Pipe
def gopen_rclone(url, mode='rb', bufsize=1024 * 1024 * 32):
"""Open a URL with `curl`.
:param url: rclone url, e.g. data:bucket1/foo.tar. data should be configured.
:param mode: file mode
:param bufsize: buffer size
"""
url = url.replace('rclone://', '')
if mode[0] == 'r':
cmd = f"rclone cat '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 23],
) # skipcq: BAN-B604
elif mode[0] == 'w':
cmd = f"rclone cp - '{url}'"
return Pipe(
cmd,
mode=mode,
shell=True,
bufsize=bufsize,
ignore_status=[141, 26],
) # skipcq: BAN-B604
else:
raise ValueError(f'{mode}: unknown mode')
def gopen_boto3(url, mode='rb', bufsize=8192 * 2):
"""Open a URL with boto3 API.
:param url: boto3 url, e.g. boto3://bucket1/foo.tar. data should be configured.
:param mode: file mode
:param bufsize: buffer size
"""
import boto3
# boto3.set_stream_logger('botocore', level='DEBUG')
if url.startswith('boto3://'):
url = url.replace('boto3://', '')
need_meta = False
else:
url = url.replace('metaboto3://', '')
need_meta = True
endpoint_url = os.environ.get('S3_ENDPOINT_URL', None)
access_key = os.environ.get('S3_ACCESS_KEY_ID', None)
secret_key = os.environ.get('S3_SECRET_ACCESS_KEY', None)
if mode[0] == 'r':
s3_client = boto3.client('s3',
endpoint_url=endpoint_url,
aws_access_key_id=access_key,
aws_secret_access_key=secret_key)
bucket, key = url.split('/', 1)
if need_meta:
# download a meta json
meta_file_key = key.split('.')[0] + '.meta.jsonl'
meta_stream = io.BytesIO()
s3_client.download_fileobj(bucket, meta_file_key, meta_stream)
meta_stream.seek(0)
meta_stream.name = meta_file_key
else:
meta_stream = None
# data tar stream
response = s3_client.get_object(Bucket=bucket,
Key=key) # Range optional
response['Body'].name = key # actually not used
response['Body'].meta_stream = meta_stream
return response['Body']
else:
raise ValueError(f'{mode}: unknown mode')
gopen_schemes['rclone'] = gopen_rclone
gopen_schemes['boto3'] = gopen_boto3
gopen_schemes['metaboto3'] = gopen_boto3
import math
from typing import List, Union
import numpy as np
import torch
from omegaconf import ListConfig
from sgm.util import instantiate_from_config
def read_from_file(p, rank=0, world_size=1):
with open(p) as fin:
cnt = -1
for l in fin:
cnt += 1
if cnt % world_size != rank:
continue
yield l.strip(), cnt
def disable_all_init():
"""Disable all redundant torch default initialization to accelerate model
creation."""
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
setattr(torch.nn.modules.sparse.Embedding, 'reset_parameters',
lambda self: None)
setattr(torch.nn.modules.conv.Conv2d, 'reset_parameters',
lambda self: None)
setattr(torch.nn.modules.normalization.GroupNorm, 'reset_parameters',
lambda self: None)
def get_unique_embedder_keys_from_conditioner(conditioner):
return list({x.input_key for x in conditioner.embedders})
def get_batch(keys,
value_dict,
N: Union[List, ListConfig],
T=None,
device='cuda'):
batch = {}
batch_uc = {}
for key in keys:
if key == 'txt':
batch['txt'] = np.repeat([value_dict['prompt']],
repeats=math.prod(N)).reshape(N).tolist()
batch_uc['txt'] = np.repeat(
[value_dict['negative_prompt']],
repeats=math.prod(N)).reshape(N).tolist()
else:
batch[key] = value_dict[key]
if T is not None:
batch['num_video_frames'] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def decode(first_stage_model, latent):
first_stage_model.to(torch.float16)
latent = latent.to(torch.float16)
recons = []
T = latent.shape[2]
if T > 2:
loop_num = (T - 1) // 2
for i in range(loop_num):
if i == 0:
start_frame, end_frame = 0, 3
else:
start_frame, end_frame = i * 2 + 1, i * 2 + 3
if i == loop_num - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
with torch.no_grad():
recon = first_stage_model.decode(
latent[:, :, start_frame:end_frame].contiguous(),
clear_fake_cp_cache=clear_fake_cp_cache)
recons.append(recon)
else:
clear_fake_cp_cache = True
if latent.shape[2] > 1:
for m in first_stage_model.modules():
m.force_split = True
recon = first_stage_model.decode(
latent.contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
recons.append(recon)
recon = torch.cat(recons, dim=2).to(torch.float32)
samples_x = recon.permute(0, 2, 1, 3, 4).contiguous()
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0).cpu()
samples = (samples * 255).squeeze(0).permute(0, 2, 3, 1)
save_frames = samples
return save_frames
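# --- Illustrative sketch (not part of the original module). ---
# How decode() windows the latent time axis: the first chunk takes 3 latent
# frames and every later chunk takes 2, so a latent with T frames is covered
# by (T - 1) // 2 decoder calls.
def _example_decode_windows(T=9):
    windows = []
    for i in range((T - 1) // 2):
        if i == 0:
            windows.append((0, 3))
        else:
            windows.append((i * 2 + 1, i * 2 + 3))
    return windows  # T=9 -> [(0, 3), (3, 5), (5, 7), (7, 9)]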
def save_mem_decode(first_stage_model, latent):
l_h, l_w = latent.shape[3], latent.shape[4]
T = latent.shape[2]
F = 8
# split spatial along h w
num_h_splits = 1
num_w_splits = 2
ori_video = torch.zeros((1, 3, 1 + 4 * (T - 1), l_h * 8, l_w * 8),
device=latent.device)
for h_idx in range(num_h_splits):
for w_idx in range(num_w_splits):
start_h = h_idx * latent.shape[3] // num_h_splits
end_h = (h_idx + 1) * latent.shape[3] // num_h_splits
start_w = w_idx * latent.shape[4] // num_w_splits
end_w = (w_idx + 1) * latent.shape[4] // num_w_splits
latent_overlap = 16
if (start_h - latent_overlap >= 0) and (num_h_splits > 1):
real_start_h = start_h - latent_overlap
h_start_overlap = latent_overlap * F
else:
h_start_overlap = 0
real_start_h = start_h
if (end_h + latent_overlap <= l_h) and (num_h_splits > 1):
real_end_h = end_h + latent_overlap
h_end_overlap = latent_overlap * F
else:
h_end_overlap = 0
real_end_h = end_h
if (start_w - latent_overlap >= 0) and (num_w_splits > 1):
real_start_w = start_w - latent_overlap
w_start_overlap = latent_overlap * F
else:
w_start_overlap = 0
real_start_w = start_w
if (end_w + latent_overlap <= l_w) and (num_w_splits > 1):
real_end_w = end_w + latent_overlap
w_end_overlap = latent_overlap * F
else:
w_end_overlap = 0
real_end_w = end_w
latent_slice = latent[:, :, :, real_start_h:real_end_h,
real_start_w:real_end_w]
recon = decode(first_stage_model, latent_slice)
recon = recon.permute(3, 0, 1, 2).contiguous()[None]
recon = recon[:, :, :,
h_start_overlap:recon.shape[3] - h_end_overlap,
w_start_overlap:recon.shape[4] - w_end_overlap]
ori_video[:, :, :, start_h * 8:end_h * 8,
start_w * 8:end_w * 8] = recon
ori_video = ori_video.squeeze(0)
ori_video = ori_video.permute(1, 2, 3, 0).contiguous().cpu()
return ori_video
def prepare_input(text, model, T, negative_prompt=None, pos_prompt=None):
if negative_prompt is None:
negative_prompt = ''
if pos_prompt is None:
pos_prompt = ''
value_dict = {
'prompt': text + pos_prompt,
'negative_prompt': negative_prompt,
'num_frames': torch.tensor(T).unsqueeze(0),
}
print(value_dict)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict, [1])
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=['txt'],
)
for k in c:
if not k == 'crossattn':
c[k], uc[k] = map(lambda y: y[k][:math.prod([1])].to('cuda'),
(c, uc))
return c, uc
def save_memory_encode_first_stage(x, model):
splits_x = torch.split(x, [17, 16, 16], dim=2)
all_out = []
with torch.autocast('cuda', enabled=False):
for idx, input_x in enumerate(splits_x):
if idx == len(splits_x) - 1:
clear_fake_cp_cache = True
else:
clear_fake_cp_cache = False
out = model.first_stage_model.encode(
input_x.contiguous(), clear_fake_cp_cache=clear_fake_cp_cache)
all_out.append(out)
z = torch.cat(all_out, dim=2)
z = model.scale_factor * z
return z
def seed_everything(seed: int = 42):
import os
import random
import numpy as np
import torch
# Python random module
random.seed(seed)
# Numpy
np.random.seed(seed)
# PyTorch
torch.manual_seed(seed)
# If using CUDA
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
# # CuDNN
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# OS environment
os.environ['PYTHONHASHSEED'] = str(seed)
def get_time_slice_vae():
vae_config = {
'target': 'vae_modules.autoencoder.VideoAutoencoderInferenceWrapper',
'params': {
'cp_size': 1,
'ckpt_path': './checkpoints/3d-vae.pt',
'ignore_keys': ['loss'],
'loss_config': {
'target': 'torch.nn.Identity'
},
'regularizer_config': {
'target':
'vae_modules.regularizers.DiagonalGaussianRegularizer'
},
'encoder_config': {
'target':
'vae_modules.cp_enc_dec.SlidingContextParallelEncoder3D',
'params': {
'double_z': True,
'z_channels': 16,
'resolution': 256,
'in_channels': 3,
'out_ch': 3,
'ch': 128,
'ch_mult': [1, 2, 2, 4],
'attn_resolutions': [],
'num_res_blocks': 3,
'dropout': 0.0,
'gather_norm': False
}
},
'decoder_config': {
'target': 'vae_modules.cp_enc_dec.ContextParallelDecoder3D',
'params': {
'double_z': True,
'z_channels': 16,
'resolution': 256,
'in_channels': 3,
'out_ch': 3,
'ch': 128,
'ch_mult': [1, 2, 2, 4],
'attn_resolutions': [],
'num_res_blocks': 3,
'dropout': 0.0,
'gather_norm': False
}
}
}
}
vae = instantiate_from_config(vae_config).eval().half().cuda()
return vae
import math
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from packaging import version
from torch import nn
if version.parse(torch.__version__) >= version.parse('2.0.0'):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
BACKEND_MAP = {
SDPBackend.MATH: {
'enable_math': True,
'enable_flash': False,
'enable_mem_efficient': False,
},
SDPBackend.FLASH_ATTENTION: {
'enable_math': False,
'enable_flash': True,
'enable_mem_efficient': False,
},
SDPBackend.EFFICIENT_ATTENTION: {
'enable_math': False,
'enable_flash': False,
'enable_mem_efficient': True,
},
None: {
'enable_math': True,
'enable_flash': True,
'enable_mem_efficient': True
},
}
else:
from contextlib import nullcontext
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
print(
f'No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, '
f'you are using PyTorch {torch.__version__}. You might want to consider upgrading.'
)
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
from modules.utils import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
backend=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
self.backend = backend
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
h = self.heads
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
k = repeat(k[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_cp)
v = repeat(v[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_cp)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))
## old
"""
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
"""
## new
with sdp_kernel(**BACKEND_MAP[self.backend]):
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask) # scale is dim_head ** -0.5 per default
del q, k, v
out = rearrange(out, 'b h n d -> b n (h d)', h=h)
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
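# --- Illustrative usage sketch (not part of the original module). ---
# Cross-attention between image tokens and text tokens; without `context` the
# same module acts as self-attention.  Assumes PyTorch >= 2.0 so the
# scaled_dot_product_attention path and BACKEND_MAP are available.
def _example_cross_attention():
    attn = CrossAttention(query_dim=320, context_dim=1024, heads=8, dim_head=40)
    x = torch.randn(2, 4096, 320)       # (batch, image tokens, query_dim)
    context = torch.randn(2, 77, 1024)  # (batch, text tokens, context_dim)
    out = attn(x, context=context)      # (2, 4096, 320)
    return out.shape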
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
**kwargs):
super().__init__()
print(
f'Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using '
f'{heads} heads with a dimension of {dim_head}.')
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(
self,
x,
context=None,
mask=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0,
):
if additional_tokens is not None:
# get the number of masked tokens at the beginning of the output sequence
n_tokens_to_mask = additional_tokens.shape[1]
# add additional token
x = torch.cat([additional_tokens, x], dim=1)
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
if n_times_crossframe_attn_in_self:
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
k = repeat(
k[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_times_crossframe_attn_in_self,
)
v = repeat(
v[::n_times_crossframe_attn_in_self],
'b ... -> (b n) ...',
n=n_times_crossframe_attn_in_self,
)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=self.attention_op)
# TODO: Use this directly in the attention operation, as a bias
if exists(mask):
raise NotImplementedError
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
if additional_tokens is not None:
# remove additional token
out = out[:, n_tokens_to_mask:]
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
'softmax': CrossAttention, # vanilla attention
'softmax-xformers': MemoryEfficientCrossAttention, # ampere
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attn_mode='softmax',
sdp_backend=None,
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != 'softmax' and not XFORMERS_IS_AVAILABLE:
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f'This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}'
)
attn_mode = 'softmax'
elif attn_mode == 'softmax' and not SDP_IS_AVAILABLE:
print(
'We do not support vanilla attention anymore, as it is too expensive. Sorry.'
)
if not XFORMERS_IS_AVAILABLE:
assert False, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print('Falling back to xformers efficient attention.')
attn_mode = 'softmax-xformers'
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse('2.0.0'):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
else:
assert sdp_backend is None
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None,
backend=sdp_backend,
) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
backend=sdp_backend,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
print(f'{self.__class__.__name__} is using checkpointing')
def forward(self,
x,
context=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0):
kwargs = {'x': x}
if context is not None:
kwargs.update({'context': context})
if additional_tokens is not None:
kwargs.update({'additional_tokens': additional_tokens})
if n_times_crossframe_attn_in_self:
kwargs.update({
'n_times_crossframe_attn_in_self':
n_times_crossframe_attn_in_self
})
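        # NOTE (added comment): `kwargs` built above is currently not forwarded;
        # `checkpoint` below only passes (x, context) into `_forward`, so
        # `additional_tokens` and `n_times_crossframe_attn_in_self` fall back to
        # their defaults on this path (see the commented-out mixed_checkpoint
        # alternative below).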
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self,
x,
context=None,
additional_tokens=None,
n_times_crossframe_attn_in_self=0):
x = (self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None,
additional_tokens=additional_tokens,
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
if not self.disable_self_attn else 0,
) + x)
x = self.attn2(self.norm2(x),
context=context,
additional_tokens=additional_tokens) + x
x = self.ff(self.norm3(x)) + x
return x
class BasicTransformerSingleLayerBlock(nn.Module):
ATTENTION_MODES = {
'softmax': CrossAttention, # vanilla attention
'softmax-xformers':
MemoryEfficientCrossAttention, # on the A100s not quite as fast as the above version
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
}
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
attn_mode='softmax',
):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim,
)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context) + x
x = self.ff(self.norm2(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape it to (b, t, d).
    Then apply standard transformer blocks.
    Finally, reshape back to an image.
    NEW: use_linear replaces the 1x1 convs with linear layers for efficiency.
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
disable_self_attn=False,
use_linear=False,
attn_type='softmax',
use_checkpoint=True,
# sdp_backend=SDPBackend.FLASH_ATTENTION
sdp_backend=None,
):
super().__init__()
print(
f'constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads'
)
from omegaconf import ListConfig
if exists(context_dim) and not isinstance(context_dim,
(list, ListConfig)):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
print(
f'WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, '
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.
assert all(
map(lambda x: x == context_dim[0], context_dim)
                ), 'need homogeneous context_dim to match depth automatically'
context_dim = depth * [context_dim[0]]
elif context_dim is None:
context_dim = [None] * depth
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
attn_mode=attn_type,
checkpoint=use_checkpoint,
sdp_backend=sdp_backend,
) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
if i > 0 and len(context) == 1:
i = 0 # use same context for each block
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
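# A minimal shape-check sketch (illustrative helper, not called anywhere in this
# repo; the name below is hypothetical). It assumes PyTorch >= 2.0 so the default
# 'softmax' (scaled-dot-product) attention path is available, and shows that
# SpatialTransformer is shape-preserving: (b, c, h, w) in -> (b, c, h, w) out.
def _spatial_transformer_shape_demo():
    x = torch.randn(2, 64, 16, 16)
    block = SpatialTransformer(in_channels=64,
                               n_heads=4,
                               d_head=16,
                               depth=1,
                               use_linear=True,
                               use_checkpoint=False)
    out = block(x)  # no context given, so cross-attention acts as self-attention
    assert out.shape == x.shape
    return out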
import logging
import math
import random
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed
import torch.nn as nn
from einops import rearrange
from packaging import version
from sgm.util import (default, get_context_parallel_group,
get_context_parallel_group_rank, get_obj_from_str,
initialize_context_parallel, instantiate_from_config,
is_context_parallel_initialized)
from vae_modules.cp_enc_dec import _conv_gather, _conv_split
from vae_modules.ema import LitEma
logpy = logging.getLogger(__name__)
import os
class AbstractAutoencoder(pl.LightningModule):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
"""
def __init__(
self,
ema_decay: Union[None, float] = None,
monitor: Union[None, str] = None,
input_key: str = 'jpg',
):
super().__init__()
self.input_key = input_key
self.use_ema = ema_decay is not None
if monitor is not None:
self.monitor = monitor
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logpy.info(
                f'Keeping EMAs of {len(list(self.model_ema.buffers()))} buffers.')
if version.parse(torch.__version__) >= version.parse('2.0.0'):
self.automatic_optimization = False
# def apply_ckpt(self, ckpt: Union[None, str, dict]):
# if ckpt is None:
# return
# if isinstance(ckpt, str):
# ckpt = {
# "target": "sgm.modules.checkpoint.CheckpointEngine",
# "params": {"ckpt_path": ckpt},
# }
# engine = instantiate_from_config(ckpt)
# engine(self)
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
if os.environ.get('SKIP_LOAD', False):
print(f'skip loading from {path}')
return
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print(f'Deleting key {k} from state_dict.')
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print('Missing keys: ', missing_keys)
print('Unexpected keys: ', unexpected_keys)
print(f'Restored from {path}')
@abstractmethod
def get_input(self, batch) -> Any:
raise NotImplementedError()
def on_train_batch_end(self, *args, **kwargs):
# for EMA computation
if self.use_ema:
self.model_ema(self)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logpy.info(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logpy.info(f'{context}: Restored training weights')
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
'encode()-method of abstract base class called')
@abstractmethod
def decode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError(
'decode()-method of abstract base class called')
def instantiate_optimizer_from_config(self, params, lr, cfg):
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg['target'])(params,
lr=lr,
**cfg.get('params', dict()))
def configure_optimizers(self) -> Any:
raise NotImplementedError()
class AutoencodingEngine(AbstractAutoencoder):
"""
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
(we also restore them explicitly as special cases for legacy reasons).
Regularizations such as KL or VQ are moved to the regularizer class.
"""
def __init__(
self,
*args,
encoder_config: Dict,
decoder_config: Dict,
loss_config: Dict,
regularizer_config: Dict,
optimizer_config: Union[Dict, None] = None,
lr_g_factor: float = 1.0,
trainable_ae_params: Optional[List[List[str]]] = None,
ae_optimizer_args: Optional[List[dict]] = None,
trainable_disc_params: Optional[List[List[str]]] = None,
disc_optimizer_args: Optional[List[dict]] = None,
disc_start_iter: int = 0,
diff_boost_factor: float = 3.0,
ckpt_engine: Union[None, str, dict] = None,
ckpt_path: Optional[str] = None,
additional_decode_keys: Optional[List[str]] = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.automatic_optimization = False # pytorch lightning
self.encoder = instantiate_from_config(encoder_config)
self.decoder = instantiate_from_config(decoder_config)
self.loss = instantiate_from_config(loss_config)
self.regularization = instantiate_from_config(regularizer_config)
self.optimizer_config = default(optimizer_config,
{'target': 'torch.optim.Adam'})
self.diff_boost_factor = diff_boost_factor
self.disc_start_iter = disc_start_iter
self.lr_g_factor = lr_g_factor
self.trainable_ae_params = trainable_ae_params
if self.trainable_ae_params is not None:
self.ae_optimizer_args = default(
ae_optimizer_args,
[{} for _ in range(len(self.trainable_ae_params))],
)
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
else:
            self.ae_optimizer_args = [{}] # makes type consistent
self.trainable_disc_params = trainable_disc_params
if self.trainable_disc_params is not None:
self.disc_optimizer_args = default(
disc_optimizer_args,
[{} for _ in range(len(self.trainable_disc_params))],
)
assert len(self.disc_optimizer_args) == len(
self.trainable_disc_params)
else:
            self.disc_optimizer_args = [{}] # makes type consistent
if ckpt_path is not None:
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
logpy.warn(
                'Checkpoint path is deprecated, use `checkpoint_engine` instead'
)
self.apply_ckpt(default(ckpt_path, ckpt_engine))
self.additional_decode_keys = set(default(additional_decode_keys, []))
def get_input(self, batch: Dict) -> torch.Tensor:
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in channels-first
        # format (e.g., bchw instead of bhwc)
return batch[self.input_key]
def get_autoencoder_params(self) -> list:
params = []
if hasattr(self.loss, 'get_trainable_autoencoder_parameters'):
params += list(self.loss.get_trainable_autoencoder_parameters())
if hasattr(self.regularization, 'get_trainable_parameters'):
params += list(self.regularization.get_trainable_parameters())
params = params + list(self.encoder.parameters())
params = params + list(self.decoder.parameters())
return params
def get_discriminator_params(self) -> list:
if hasattr(self.loss, 'get_trainable_parameters'):
params = list(
self.loss.get_trainable_parameters()) # e.g., discriminator
else:
params = []
return params
def get_last_layer(self):
return self.decoder.get_last_layer()
def encode(self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
**kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
z = self.encoder(x, **kwargs)
if unregularized:
return z, dict()
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.decoder(z, **kwargs)
return x
def forward(
self, x: torch.Tensor, **additional_decode_kwargs
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x, return_reg_log=True)
dec = self.decode(z, **additional_decode_kwargs)
return z, dec, reg_log
def inner_training_step(self,
batch: dict,
batch_idx: int,
optimizer_idx: int = 0) -> torch.Tensor:
x = self.get_input(batch)
additional_decode_kwargs = {
key: batch[key]
for key in self.additional_decode_keys.intersection(batch)
}
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
if hasattr(self.loss, 'forward_keys'):
extra_info = {
'z': z,
'optimizer_idx': optimizer_idx,
'global_step': self.global_step,
'last_layer': self.get_last_layer(),
'split': 'train',
'regularization_log': regularization_log,
'autoencoder': self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
if optimizer_idx == 0:
# autoencode
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {'train/loss/rec': aeloss.detach()}
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True,
sync_dist=False,
)
self.log(
'loss',
aeloss.mean().detach(),
prog_bar=True,
logger=False,
on_epoch=False,
on_step=True,
)
return aeloss
elif optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
# -> discriminator always needs to return a tuple
self.log_dict(log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True)
return discloss
else:
raise NotImplementedError(f'Unknown optimizer {optimizer_idx}')
def training_step(self, batch: dict, batch_idx: int):
opts = self.optimizers()
if not isinstance(opts, list):
# Non-adversarial case
opts = [opts]
optimizer_idx = batch_idx % len(opts)
if self.global_step < self.disc_start_iter:
optimizer_idx = 0
opt = opts[optimizer_idx]
opt.zero_grad()
with opt.toggle_model():
loss = self.inner_training_step(batch,
batch_idx,
optimizer_idx=optimizer_idx)
self.manual_backward(loss)
opt.step()
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch,
batch_idx,
postfix='_ema')
log_dict.update(log_dict_ema)
return log_dict
def _validation_step(self,
batch: dict,
batch_idx: int,
postfix: str = '') -> Dict:
x = self.get_input(batch)
z, xrec, regularization_log = self(x)
if hasattr(self.loss, 'forward_keys'):
extra_info = {
'z': z,
'optimizer_idx': 0,
'global_step': self.global_step,
'last_layer': self.get_last_layer(),
'split': 'val' + postfix,
'regularization_log': regularization_log,
'autoencoder': self,
}
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
else:
extra_info = dict()
out_loss = self.loss(x, xrec, **extra_info)
if isinstance(out_loss, tuple):
aeloss, log_dict_ae = out_loss
else:
# simple loss function
aeloss = out_loss
log_dict_ae = {f'val{postfix}/loss/rec': aeloss.detach()}
full_log_dict = log_dict_ae
if 'optimizer_idx' in extra_info:
extra_info['optimizer_idx'] = 1
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
full_log_dict.update(log_dict_disc)
self.log(
f'val{postfix}/loss/rec',
log_dict_ae[f'val{postfix}/loss/rec'],
sync_dist=True,
)
self.log_dict(full_log_dict, sync_dist=True)
return full_log_dict
def get_param_groups(
self, parameter_names: List[List[str]],
optimizer_args: List[dict]) -> Tuple[List[Dict[str, Any]], int]:
groups = []
num_params = 0
for names, args in zip(parameter_names, optimizer_args):
params = []
for pattern_ in names:
pattern_params = []
pattern = re.compile(pattern_)
for p_name, param in self.named_parameters():
if re.match(pattern, p_name):
pattern_params.append(param)
num_params += param.numel()
if len(pattern_params) == 0:
logpy.warn(
f'Did not find parameters for pattern {pattern_}')
params.extend(pattern_params)
groups.append({'params': params, **args})
return groups, num_params
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
if self.trainable_ae_params is None:
ae_params = self.get_autoencoder_params()
else:
ae_params, num_ae_params = self.get_param_groups(
self.trainable_ae_params, self.ae_optimizer_args)
logpy.info(
f'Number of trainable autoencoder parameters: {num_ae_params:,}'
)
if self.trainable_disc_params is None:
disc_params = self.get_discriminator_params()
else:
disc_params, num_disc_params = self.get_param_groups(
self.trainable_disc_params, self.disc_optimizer_args)
logpy.info(
f'Number of trainable discriminator parameters: {num_disc_params:,}'
)
opt_ae = self.instantiate_optimizer_from_config(
ae_params,
default(self.lr_g_factor, 1.0) * self.learning_rate,
self.optimizer_config,
)
opts = [opt_ae]
if len(disc_params) > 0:
opt_disc = self.instantiate_optimizer_from_config(
disc_params, self.learning_rate, self.optimizer_config)
opts.append(opt_disc)
return opts
@torch.no_grad()
def log_images(self,
batch: dict,
additional_log_kwargs: Optional[Dict] = None,
**kwargs) -> dict:
log = dict()
additional_decode_kwargs = {}
x = self.get_input(batch)
additional_decode_kwargs.update({
key: batch[key]
for key in self.additional_decode_keys.intersection(batch)
})
_, xrec, _ = self(x, **additional_decode_kwargs)
log['inputs'] = x
log['reconstructions'] = xrec
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
diff.clamp_(0, 1.0)
log['diff'] = 2.0 * diff - 1.0
# diff_boost shows location of small errors, by boosting their
# brightness.
log['diff_boost'] = 2.0 * torch.clamp(self.diff_boost_factor * diff,
0.0, 1.0) - 1
if hasattr(self.loss, 'log_images'):
log.update(self.loss.log_images(x, xrec))
with self.ema_scope():
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
log['reconstructions_ema'] = xrec_ema
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
diff_ema.clamp_(0, 1.0)
log['diff_ema'] = 2.0 * diff_ema - 1.0
log['diff_boost_ema'] = 2.0 * torch.clamp(
self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
if additional_log_kwargs:
additional_decode_kwargs.update(additional_log_kwargs)
_, xrec_add, _ = self(x, **additional_decode_kwargs)
log_str = 'reconstructions-' + '-'.join([
f'{key}={additional_log_kwargs[key]}'
for key in additional_log_kwargs
])
log[log_str] = xrec_add
return log
class AutoencodingEngineLegacy(AutoencodingEngine):
def __init__(self, embed_dim: int, **kwargs):
self.max_batch_size = kwargs.pop('max_batch_size', None)
ddconfig = kwargs.pop('ddconfig')
ckpt_path = kwargs.pop('ckpt_path', None)
ckpt_engine = kwargs.pop('ckpt_engine', None)
super().__init__(
encoder_config={
'target': 'sgm.modules.diffusionmodules.model.Encoder',
'params': ddconfig,
},
decoder_config={
'target': 'sgm.modules.diffusionmodules.model.Decoder',
'params': ddconfig,
},
**kwargs,
)
self.quant_conv = torch.nn.Conv2d(
(1 + ddconfig['double_z']) * ddconfig['z_channels'],
(1 + ddconfig['double_z']) * embed_dim,
1,
)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
self.embed_dim = embed_dim
self.apply_ckpt(default(ckpt_path, ckpt_engine))
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
z = self.quant_conv(z)
else:
N = x.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
z = list()
for i_batch in range(n_batches):
z_batch = self.encoder(x[i_batch * bs:(i_batch + 1) * bs])
z_batch = self.quant_conv(z_batch)
z.append(z_batch)
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)
else:
N = z.shape[0]
bs = self.max_batch_size
n_batches = int(math.ceil(N / bs))
dec = list()
for i_batch in range(n_batches):
dec_batch = self.post_quant_conv(z[i_batch * bs:(i_batch + 1) *
bs])
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
dec.append(dec_batch)
dec = torch.cat(dec, 0)
return dec
class AutoencoderKL(AutoencodingEngineLegacy):
def __init__(self, **kwargs):
if 'lossconfig' in kwargs:
kwargs['loss_config'] = kwargs.pop('lossconfig')
super().__init__(
regularizer_config={
'target': ('sgm.modules.autoencoding.regularizers'
'.DiagonalGaussianRegularizer')
},
**kwargs,
)
class IdentityFirstStage(AbstractAutoencoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def get_input(self, x: Any) -> Any:
return x
def encode(self, x: Any, *args, **kwargs) -> Any:
return x
def decode(self, x: Any, *args, **kwargs) -> Any:
return x
class VideoAutoencodingEngine(AutoencodingEngine):
def __init__(
self,
ckpt_path: Union[None, str] = None,
ignore_keys: Union[Tuple, list] = (),
image_video_weights=[1, 1],
only_train_decoder=False,
context_parallel_size=0,
**kwargs,
):
super().__init__(**kwargs)
self.context_parallel_size = context_parallel_size
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def log_videos(self,
batch: dict,
additional_log_kwargs: Optional[Dict] = None,
**kwargs) -> dict:
return self.log_images(batch, additional_log_kwargs, **kwargs)
def get_input(self, batch: dict) -> torch.Tensor:
if self.context_parallel_size > 0:
if not is_context_parallel_initialized():
initialize_context_parallel(self.context_parallel_size)
batch = batch[self.input_key]
global_src_rank = get_context_parallel_group_rank(
) * self.context_parallel_size
torch.distributed.broadcast(batch,
src=global_src_rank,
group=get_context_parallel_group())
batch = _conv_split(batch, dim=2, kernel_size=1)
return batch
return batch[self.input_key]
def apply_ckpt(self, ckpt: Union[None, str, dict]):
if ckpt is None:
return
self.init_from_ckpt(ckpt)
def init_from_ckpt(self, path, ignore_keys=list()):
if os.environ.get('SKIP_LOAD', False):
print(f'skip loading from {path}')
return
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print(f'Deleting key {k} from state_dict.')
del sd[k]
missing_keys, unexpected_keys = self.load_state_dict(sd, strict=False)
print('Missing keys: ', missing_keys)
print('Unexpected keys: ', unexpected_keys)
print(f'Restored from {path}')
class VideoAutoencoderInferenceWrapper(VideoAutoencodingEngine):
def __init__(
self,
cp_size=0,
*args,
**kwargs,
):
self.cp_size = cp_size
        super().__init__(*args, **kwargs)
def encode(
self,
x: torch.Tensor,
return_reg_log: bool = False,
unregularized: bool = False,
input_cp: bool = False,
output_cp: bool = False,
**kwargs,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.cp_size > 0 and not input_cp:
            if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(x,
src=global_src_rank,
group=get_context_parallel_group())
x = _conv_split(x, dim=2, kernel_size=1)
if return_reg_log:
z, reg_log = super().encode(x, return_reg_log, unregularized,
**kwargs)
else:
z = super().encode(x, return_reg_log, unregularized, **kwargs)
if self.cp_size > 0 and not output_cp:
z = _conv_gather(z, dim=2, kernel_size=1)
if return_reg_log:
return z, reg_log
return z
def decode(
self,
z: torch.Tensor,
input_cp: bool = False,
output_cp: bool = False,
split_kernel_size: int = 1,
**kwargs,
):
if self.cp_size > 0 and not input_cp:
            if not is_context_parallel_initialized():
initialize_context_parallel(self.cp_size)
global_src_rank = get_context_parallel_group_rank() * self.cp_size
torch.distributed.broadcast(z,
src=global_src_rank,
group=get_context_parallel_group())
z = _conv_split(z, dim=2, kernel_size=split_kernel_size)
x = super().decode(z, **kwargs)
if self.cp_size > 0 and not output_cp:
x = _conv_gather(x, dim=2, kernel_size=split_kernel_size)
return x
def forward(
self,
x: torch.Tensor,
input_cp: bool = False,
latent_cp: bool = False,
output_cp: bool = False,
**additional_decode_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
z, reg_log = self.encode(x,
return_reg_log=True,
input_cp=input_cp,
output_cp=latent_cp)
dec = self.decode(z,
input_cp=latent_cp,
output_cp=output_cp,
**additional_decode_kwargs)
return z, dec, reg_log
import math
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype
from beartype.typing import List, Optional, Tuple, Union
from einops import rearrange
from sgm.util import (get_context_parallel_group,
get_context_parallel_group_rank,
get_context_parallel_rank,
get_context_parallel_world_size)
# try:
from vae_modules.utils import SafeConv3d as Conv3d
# except:
# # Degrade to normal Conv3d if SafeConv3d is not available
# from torch.nn import Conv3d
def cast_tuple(t, length=1):
return t if isinstance(t, tuple) else ((t, ) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
def exists(v):
return v is not None
def pair(t):
return t if isinstance(t, tuple) else (t, t)
def get_timestep_embedding(timesteps, embedding_dim):
"""
    Build sinusoidal timestep embeddings, as in Denoising Diffusion
    Probabilistic Models (originally adapted from Fairseq).
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
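# Illustrative example (added for documentation only; the helper name is
# hypothetical and unused elsewhere): for a batch of integer timesteps the
# embedding has shape (batch, embedding_dim), with the first half holding sine
# terms and the second half cosine terms of geometrically spaced frequencies.
def _timestep_embedding_demo():
    t = torch.arange(4)                       # timesteps 0..3, shape (4,)
    emb = get_timestep_embedding(t, embedding_dim=8)
    assert emb.shape == (4, 8)                # (batch, embedding_dim)
    return emb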
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def leaky_relu(p=0.1):
return nn.LeakyReLU(p)
def _split(input_, dim):
cp_world_size = get_context_parallel_world_size()
if cp_world_size == 1:
return input_
cp_rank = get_context_parallel_rank()
# print('in _split, cp_rank:', cp_rank, 'input_size:', input_.shape)
    input_first_frame_ = input_.transpose(0, dim)[:1].transpose(
        0, dim).contiguous()
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
dim_size = input_.size()[dim] // cp_world_size
input_list = torch.split(input_, dim_size, dim=dim)
output = input_list[cp_rank]
if cp_rank == 0:
        output = torch.cat([input_first_frame_, output], dim=dim)
output = output.contiguous()
# print('out _split, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
def _gather(input_, dim):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_frame_ = input_.transpose(0,
dim)[:1].transpose(0,
dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[1:].transpose(0, dim).contiguous()
tensor_list = [
torch.empty_like(torch.cat([input_first_frame_, input_], dim=dim))
] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
if cp_rank == 0:
input_ = torch.cat([input_first_frame_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _gather, cp_rank:', cp_rank, 'output_size:', output.shape)
return output
def _conv_split(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
# print('in _conv_split, cp_rank:', cp_rank, 'input_size:', input_.shape)
cp_rank = get_context_parallel_rank()
dim_size = (input_.size()[dim] - kernel_size) // cp_world_size
if cp_rank == 0:
output = input_.transpose(dim, 0)[:dim_size + kernel_size].transpose(
dim, 0)
else:
# output = input_.transpose(dim, 0)[cp_rank * dim_size + 1:(cp_rank + 1) * dim_size + kernel_size].transpose(dim, 0)
output = input_.transpose(
dim, 0)[cp_rank * dim_size + kernel_size:(cp_rank + 1) * dim_size +
kernel_size].transpose(dim, 0)
output = output.contiguous()
# print('out _conv_split, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
def _conv_gather(input_, dim, kernel_size):
cp_world_size = get_context_parallel_world_size()
# Bypass the function if context parallel is 1
if cp_world_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
# print('in _conv_gather, cp_rank:', cp_rank, 'input_size:', input_.shape)
input_first_kernel_ = input_.transpose(0, dim)[:kernel_size].transpose(
0, dim).contiguous()
if cp_rank == 0:
input_ = input_.transpose(0, dim)[kernel_size:].transpose(
0, dim).contiguous()
else:
input_ = input_.transpose(0, dim)[max(kernel_size - 1, 0):].transpose(
0, dim).contiguous()
tensor_list = [
torch.empty_like(torch.cat([input_first_kernel_, input_], dim=dim))
] + [torch.empty_like(input_) for _ in range(cp_world_size - 1)]
if cp_rank == 0:
input_ = torch.cat([input_first_kernel_, input_], dim=dim)
tensor_list[cp_rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()
# print('out _conv_gather, cp_rank:', cp_rank, 'input_size:', output.shape)
return output
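# Worked example for _conv_split / _conv_gather (added comment; the numbers are
# purely illustrative). With kernel_size=1, cp_world_size=4 and 9 frames along
# `dim`: dim_size = (9 - 1) // 4 = 2, so rank 0 keeps frames [0:3] (the extra
# first frame plus its 2-frame chunk) and ranks 1..3 keep [3:5], [5:7], [7:9].
# _conv_gather is designed as the inverse: it all_gathers the per-rank chunks
# and re-attaches the first kernel_size frames in front, so
# _conv_gather(_conv_split(x, dim, k), dim, k) should reproduce x on every rank.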
def _pass_from_previous_rank(input_, dim, kernel_size):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size +
1:].contiguous(),
send_rank,
group=group)
if cp_rank > 0:
recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_], dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
# print('out _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
return input_
def _fake_cp_pass_from_previous_rank(input_,
dim,
kernel_size,
cache_padding=None):
# Bypass the function if kernel size is 1
if kernel_size == 1:
return input_
group = get_context_parallel_group()
cp_rank = get_context_parallel_rank()
cp_group_rank = get_context_parallel_group_rank()
cp_world_size = get_context_parallel_world_size()
# print('in _pass_from_previous_rank, cp_rank:', cp_rank, 'input_size:', input_.shape)
global_rank = torch.distributed.get_rank()
global_world_size = torch.distributed.get_world_size()
input_ = input_.transpose(0, dim)
# pass from last rank
send_rank = global_rank + 1
recv_rank = global_rank - 1
if send_rank % cp_world_size == 0:
send_rank -= cp_world_size
if recv_rank % cp_world_size == cp_world_size - 1:
recv_rank += cp_world_size
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
# req_recv = torch.distributed.recv(recv_buffer, recv_rank, group=group)
# req_recv.wait()
recv_buffer = torch.empty_like(input_[-kernel_size + 1:]).contiguous()
if cp_rank < cp_world_size - 1:
req_send = torch.distributed.isend(input_[-kernel_size +
1:].contiguous(),
send_rank,
group=group)
if cp_rank > 0:
req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
# req_send = torch.distributed.isend(input_[-kernel_size + 1:].contiguous(), send_rank, group=group)
# req_recv = torch.distributed.irecv(recv_buffer, recv_rank, group=group)
if cp_rank == 0:
if cache_padding is not None:
input_ = torch.cat(
[cache_padding.transpose(0, dim).to(input_.device), input_],
dim=0)
else:
input_ = torch.cat([input_[:1]] * (kernel_size - 1) + [input_],
dim=0)
else:
req_recv.wait()
input_ = torch.cat([recv_buffer, input_], dim=0)
input_ = input_.transpose(0, dim).contiguous()
return input_
def _drop_from_previous_rank(input_, dim, kernel_size):
input_ = input_.transpose(0, dim)[kernel_size - 1:].transpose(0, dim)
return input_
class _ConvolutionScatterToContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_split(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_gather(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionGatherFromContextParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _conv_gather(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _conv_split(grad_output, ctx.dim, ctx.kernel_size), None, None
class _ConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _pass_from_previous_rank(input_, dim, kernel_size)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim,
ctx.kernel_size), None, None
class _FakeCPConvolutionPassFromPreviousRank(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, kernel_size, cache_padding):
ctx.dim = dim
ctx.kernel_size = kernel_size
return _fake_cp_pass_from_previous_rank(input_, dim, kernel_size,
cache_padding)
@staticmethod
def backward(ctx, grad_output):
return _drop_from_previous_rank(grad_output, ctx.dim,
ctx.kernel_size), None, None, None
def conv_scatter_to_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionScatterToContextParallelRegion.apply(
input_, dim, kernel_size)
def conv_gather_from_context_parallel_region(input_, dim, kernel_size):
return _ConvolutionGatherFromContextParallelRegion.apply(
input_, dim, kernel_size)
def conv_pass_from_last_rank(input_, dim, kernel_size):
return _ConvolutionPassFromPreviousRank.apply(input_, dim, kernel_size)
def fake_cp_pass_from_previous_rank(input_, dim, kernel_size, cache_padding):
return _FakeCPConvolutionPassFromPreviousRank.apply(
input_, dim, kernel_size, cache_padding)
class ContextParallelCausalConv3d(nn.Module):
def __init__(self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
stride=1,
**kwargs):
super().__init__()
kernel_size = cast_tuple(kernel_size, 3)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
time_pad = time_kernel_size - 1
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2
self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_kernel_size = time_kernel_size
self.temporal_dim = 2
stride = (stride, stride, stride)
dilation = (1, 1, 1)
self.conv = Conv3d(chan_in,
chan_out,
kernel_size,
stride=stride,
dilation=dilation,
**kwargs)
self.cache_padding = None
def forward(self, input_, clear_cache=True):
# if input_.shape[2] == 1: # handle image
# # first frame padding
# input_parallel = torch.cat([input_] * self.time_kernel_size, dim=2)
# else:
# input_parallel = conv_pass_from_last_rank(input_, self.temporal_dim, self.time_kernel_size)
# padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
# input_parallel = F.pad(input_parallel, padding_2d, mode = 'constant', value = 0)
# output_parallel = self.conv(input_parallel)
# output = output_parallel
# return output
input_parallel = fake_cp_pass_from_previous_rank(
input_, self.temporal_dim, self.time_kernel_size,
self.cache_padding)
del self.cache_padding
self.cache_padding = None
if not clear_cache:
cp_rank, cp_world_size = get_context_parallel_rank(
), get_context_parallel_world_size()
global_rank = torch.distributed.get_rank()
if cp_world_size == 1:
self.cache_padding = (
input_parallel[:, :, -self.time_kernel_size +
1:].contiguous().detach().clone().cpu())
else:
if cp_rank == cp_world_size - 1:
torch.distributed.isend(
input_parallel[:, :, -self.time_kernel_size +
1:].contiguous(),
global_rank + 1 - cp_world_size,
group=get_context_parallel_group(),
)
if cp_rank == 0:
recv_buffer = torch.empty_like(
input_parallel[:, :, -self.time_kernel_size +
1:]).contiguous()
torch.distributed.recv(recv_buffer,
global_rank - 1 + cp_world_size,
group=get_context_parallel_group())
self.cache_padding = recv_buffer.contiguous().detach(
).clone().cpu()
padding_2d = (self.width_pad, self.width_pad, self.height_pad,
self.height_pad)
input_parallel = F.pad(input_parallel,
padding_2d,
mode='constant',
value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
return output
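# Note on the causal padding above (explanatory comment, not original code):
# the temporal axis is padded only at the front with (time_kernel_size - 1)
# frames -- replicas of the first frame (rank 0, no cache), the cached tail of
# the previous chunk, or frames received from the previous context-parallel
# rank -- while height/width get symmetric zero padding. With stride 1 the
# output therefore keeps the input's temporal length, and frame t never sees
# frames later than t.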
class ContextParallelGroupNorm(torch.nn.GroupNorm):
def forward(self, input_):
gather_flag = input_.shape[2] > 1
if gather_flag:
input_ = conv_gather_from_context_parallel_region(input_,
dim=2,
kernel_size=1)
output = super().forward(input_)
if gather_flag:
output = conv_scatter_to_context_parallel_region(output,
dim=2,
kernel_size=1)
return output
def Normalize(in_channels, gather=False, **kwargs): # same for 3D and 2D
if gather:
return ContextParallelGroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
else:
return torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
class SpatialNorm3D(nn.Module):
def __init__(
self,
f_channels,
zq_channels,
freeze_norm_layer=False,
add_conv=False,
pad_mode='constant',
gather=False,
**norm_layer_params,
):
super().__init__()
if gather:
self.norm_layer = ContextParallelGroupNorm(num_channels=f_channels,
**norm_layer_params)
else:
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels,
**norm_layer_params)
# self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
            for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if add_conv:
self.conv = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=zq_channels,
kernel_size=3,
)
self.conv_y = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
self.conv_b = ContextParallelCausalConv3d(
chan_in=zq_channels,
chan_out=f_channels,
kernel_size=1,
)
def forward(self, f, zq, clear_fake_cp_cache=True):
if hasattr(self, 'force_split') and self.force_split:
force_split = True
else:
force_split = False
        if (f.shape[2] > 1 and f.shape[2] % 2 == 1) or force_split:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = torch.nn.functional.interpolate(zq_first,
size=f_first_size,
mode='nearest')
zq_rest = torch.nn.functional.interpolate(zq_rest,
size=f_rest_size,
mode='nearest')
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = torch.nn.functional.interpolate(zq,
size=f.shape[-3:],
mode='nearest')
if self.add_conv:
zq = self.conv(zq, clear_cache=clear_fake_cp_cache)
# f = conv_gather_from_context_parallel_region(f, dim=2, kernel_size=1)
norm_f = self.norm_layer(f)
# norm_f = conv_scatter_to_context_parallel_region(norm_f, dim=2, kernel_size=1)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize3D(
in_channels,
zq_ch,
add_conv,
gather=False,
):
return SpatialNorm3D(
in_channels,
zq_ch,
gather=gather,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample3D(nn.Module):
def __init__(
self,
in_channels,
with_conv,
compress_time=False,
):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
self.compress_time = compress_time
self.scale_factor = 2
def forward(self, x):
if hasattr(self, 'force_split') and self.force_split:
force_split = True
else:
force_split = False
if self.compress_time and x.shape[2] > 1:
if x.shape[2] % 2 == 1 or force_split:
# split first frame
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = torch.nn.functional.interpolate(
x_first, scale_factor=self.scale_factor, mode='nearest')
x_rest = torch.nn.functional.interpolate(
x_rest, scale_factor=self.scale_factor, mode='nearest')
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
else:
x = torch.nn.functional.interpolate(
x, scale_factor=self.scale_factor, mode='nearest')
else:
# only interpolate 2D
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = torch.nn.functional.interpolate(x,
scale_factor=self.scale_factor,
mode='nearest')
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.with_conv:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.conv(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x
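# A small usage sketch (for illustration only; the helper name is hypothetical).
# With compress_time=True and an odd frame count, the first frame is upsampled
# only spatially while the remaining frames are also doubled in time, so
# T frames become 1 + 2 * (T - 1).
def _upsample3d_shape_demo():
    up = Upsample3D(in_channels=8, with_conv=False, compress_time=True)
    x = torch.randn(1, 8, 5, 16, 16)    # (b, c, t, h, w)
    out = up(x)
    assert out.shape == (1, 8, 9, 32, 32)
    return out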
class DownSample3D(nn.Module):
def __init__(self,
in_channels,
with_conv,
compress_time=False,
out_channels=None):
super().__init__()
self.with_conv = with_conv
if out_channels is None:
out_channels = in_channels
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=2,
padding=0)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time and x.shape[2] > 1:
h, w = x.shape[-2:]
x = rearrange(x, 'b c t h w -> (b h w) c t')
if x.shape[-1] % 2 == 1:
# split first frame
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
try:
x_rest = torch.nn.functional.avg_pool1d(x_rest,
kernel_size=2,
stride=2)
                    except Exception:
                        # fall back to a chunked avg_pool1d (e.g., on OOM)
                        print(
                            '######### chunked avg_pool1d fallback ###########'
                        )
x_rest_list = x_rest.split(len(x_rest) // 4, dim=0)
x_rest = torch.cat([
torch.nn.functional.avg_pool1d(
xr, kernel_size=2, stride=2)
for xr in x_rest_list
],
dim=0)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
else:
try:
x = torch.nn.functional.avg_pool1d(x,
kernel_size=2,
stride=2)
                except Exception:  # fall back to a chunked avg_pool1d
print(
'######### for loop the avg_pool1d in else ###########'
)
x_list = x.split(len(x) // 4, dim=0)
x = torch.cat([
torch.nn.functional.avg_pool1d(
xr, kernel_size=2, stride=2) for xr in x_list
],
dim=0)
x = rearrange(x, '(b h w) c t -> b c t h w', h=h, w=w)
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.conv(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
else:
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
return x
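# Companion sketch for DownSample3D (illustrative helper, unused elsewhere):
# with compress_time=True and an odd frame count, the first frame is kept as-is
# and the remaining frames are average-pooled in time, so T frames become
# 1 + (T - 1) // 2, while the spatial map is halved.
def _downsample3d_shape_demo():
    down = DownSample3D(in_channels=8, with_conv=False, compress_time=True)
    x = torch.randn(1, 8, 9, 16, 16)    # (b, c, t, h, w)
    out = down(x)
    assert out.shape == (1, 8, 5, 8, 8)
    return out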
class ContextParallelResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
zq_ch=None,
add_conv=False,
gather_norm=False,
normalization=Normalize,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = normalization(
in_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.conv1 = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = normalization(
out_channels,
zq_ch=zq_ch,
add_conv=add_conv,
gather=gather_norm,
)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = ContextParallelCausalConv3d(
chan_in=out_channels,
chan_out=out_channels,
kernel_size=3,
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=out_channels,
kernel_size=3,
)
else:
self.nin_shortcut = Conv3d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x, temb, zq=None, clear_fake_cp_cache=True):
h = x
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm1(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
else:
h = self.norm1(h)
# if isinstance(self.norm1, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv1(h, clear_cache=clear_fake_cp_cache)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
if zq is not None:
h = self.norm2(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
else:
h = self.norm2(h)
# if isinstance(self.norm2, torch.nn.GroupNorm):
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h, clear_cache=clear_fake_cp_cache)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x, clear_cache=clear_fake_cp_cache)
else:
x = self.nin_shortcut(x)
return x + h
class ContextParallelEncoder3D(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
pad_mode='first',
temporal_compress_times=4,
gather_norm=False,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = ContextParallelCausalConv3d(
chan_in=in_channels,
chan_out=self.ch,
kernel_size=3,
)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
gather_norm=gather_norm,
))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = DownSample3D(block_in,
resamp_with_conv,
compress_time=True)
else:
down.downsample = DownSample3D(block_in,
resamp_with_conv,
compress_time=False)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
gather_norm=gather_norm,
)
# end
self.norm_out = Normalize(block_in, gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=2 * z_channels if double_z else z_channels,
kernel_size=3,
)
def forward(self, x, **kwargs):
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class ContextParallelDecoder3D(nn.Module):
def __init__(
self,
*,
ch, # 128
out_ch, # 3
ch_mult=(1, 2, 4, 8),
num_res_blocks, # 3
attn_resolutions, # []
dropout=0.0, # 0.0
resamp_with_conv=True, # True
in_channels, # 3
resolution, # 256
z_channels, # 16
give_pre_end=False, # False
zq_ch=None, # None
add_conv=False,
pad_mode='first', # "first"
temporal_compress_times=4, # 4
gather_norm=False, # False
**ignorekwargs, # {'double_z': True}
):
super().__init__()
self.ch = ch # 128
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1, ) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
chan_out=block_in,
kernel_size=3,
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in,
with_conv=resamp_with_conv,
compress_time=False)
else:
up.upsample = Upsample3D(block_in,
with_conv=resamp_with_conv,
compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in,
zq_ch,
add_conv=add_conv,
gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=out_ch,
kernel_size=3,
)
def forward(self, z, clear_fake_cp_cache=True, **kwargs):
self.last_z_shape = z.shape
# timestep embedding
temb = None
t = z.shape[2]
# z to block_in
zq = z
h = self.conv_in(z, clear_cache=clear_fake_cp_cache)
# middle
h = self.mid.block_1(h,
temb,
zq,
clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h,
temb,
zq,
clear_fake_cp_cache=clear_fake_cp_cache)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
h, temb, zq, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq, clear_fake_cp_cache=clear_fake_cp_cache)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
return h
def get_last_layer(self):
return self.conv_out.conv.weight
class SlidingContextParallelEncoder3D(ContextParallelEncoder3D):
def forward(self, x, clear_fake_cp_cache=True):
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x, clear_cache=clear_fake_cp_cache)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](
h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
h = self.mid.block_2(h, temb, clear_fake_cp_cache=clear_fake_cp_cache)
# end
# h = conv_gather_from_context_parallel_region(h, dim=2, kernel_size=1)
h = self.norm_out(h)
# h = conv_scatter_to_context_parallel_region(h, dim=2, kernel_size=1)
h = nonlinearity(h)
h = self.conv_out(h, clear_cache=clear_fake_cp_cache)
return h
class LatentUpscaler(ContextParallelDecoder3D):
def __init__(
self,
*,
ch=128, # 128
out_ch=16, # 3
scale_factor=2, # 3
ch_mult=(2, 4), # (1, 2, 4, 8)
num_res_blocks=2, # 3
attn_resolutions=[], # []
dropout=0.0, # 0.0
resamp_with_conv=True, # True
in_channels=3, # 3
resolution=256, # 256
z_channels=16, # 16
give_pre_end=False, # False
zq_ch=None, # None
add_conv=False,
pad_mode='first', # "first"
temporal_compress_times=4, # 4
gather_norm=False, # False
double_z=True):
super(ContextParallelDecoder3D, self).__init__()
self.ch = ch # 128
self.temb_ch = 0
self.scale_factor = scale_factor
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
# log2 of temporal_compress_times
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1, ) + tuple(ch_mult)
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
self.conv_in = ContextParallelCausalConv3d(
chan_in=z_channels,
chan_out=block_in,
kernel_size=3,
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
self.mid.block_2 = ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ContextParallelResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
normalization=Normalize3D,
gather_norm=gather_norm,
))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = Upsample3D(block_in,
with_conv=resamp_with_conv,
compress_time=False)
else:
up.upsample = Upsample3D(block_in,
with_conv=resamp_with_conv,
compress_time=True)
self.up.insert(0, up)
self.norm_out = Normalize3D(block_in,
zq_ch,
add_conv=add_conv,
gather=gather_norm)
self.conv_out = ContextParallelCausalConv3d(
chan_in=block_in,
chan_out=out_ch,
kernel_size=3,
)
        # switch every Upsample3D to spatial-only 2x upsampling (disable temporal compression)
for n, m in self.named_modules():
if isinstance(m, Upsample3D):
m.compress_time = False
m.scale_factor = 2
# mini test latent upscaler
if __name__ == '__main__':
    import torch
x = torch.randn(2, 16, 49, 60, 90).cuda()
# b,c,t,h,w
print(x.shape)
model = LatentUpscaler(
ch=128, # 128
out_ch=16, # 3
ch_mult=(2, 4), # (1, 2, 4, 8)
num_res_blocks=2, # 3
attn_resolutions=[], # []
dropout=0.0, # 0.0
resamp_with_conv=True, # True
in_channels=3, # 3
resolution=256, # 256
z_channels=16, # 16
give_pre_end=False, # False
zq_ch=None, # None
add_conv=False,
pad_mode='first', # "first"
temporal_compress_times=4, # 4
gather_norm=False, # False
double_z=True)
print(model)
    model = model.cuda()
    out = model(x)
print(out.shape)
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates else torch.tensor(-1, dtype=torch.int),
)
for name, p in model.named_parameters():
if p.requires_grad:
                # replace '.' since the character is not allowed in buffer names
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,
(1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key])
shadow_params[sname].sub_(
one_minus_decay *
(shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
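# Hedged mini-test (added; not part of the original commit): exercises the EMA
# update and the store / copy_to / restore round trip on a tiny linear layer,
# so the shadow-buffer bookkeeping can be checked without the full pipeline.
if __name__ == '__main__':
    _net = nn.Linear(4, 4)
    _ema = LitEma(_net)
    with torch.no_grad():
        _net.weight.add_(1.0)  # stand-in for an optimizer step
    _ema(_net)  # EMA update of the shadow buffers
    _ema.store(_net.parameters())  # stash the raw weights
    _ema.copy_to(_net)  # evaluate with EMA-averaged weights
    _ema.restore(_net.parameters())  # put the raw weights back
    print('LitEma round trip OK:', tuple(_net.weight.shape))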
from abc import abstractmethod
from typing import Any, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class DiagonalGaussianDistribution:
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self):
# x = self.mean + self.std * torch.randn(self.mean.shape).to(
# device=self.parameters.device
# )
x = self.mean + self.std * torch.randn_like(self.mean)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var +
self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self):
return self.mean
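# Note (added; not part of the original commit): `kl()` with other=None is the
# closed-form KL(N(mean, var) || N(0, I)) = 0.5 * sum(mean^2 + var - 1 - logvar)
# over dims [1, 2, 3]; for (B, C, H, W) mean / logvar tensors this yields one
# value per batch element.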
class AbstractRegularizer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
raise NotImplementedError()
@abstractmethod
def get_trainable_parameters(self) -> Any:
raise NotImplementedError()
class IdentityRegularizer(AbstractRegularizer):
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, dict()
def get_trainable_parameters(self) -> Any:
yield from ()
def measure_perplexity(
predicted_indices: torch.Tensor,
num_centroids: int) -> Tuple[torch.Tensor, torch.Tensor]:
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # evaluate cluster perplexity: when perplexity == num_embeddings, all clusters are used exactly equally
encodings = F.one_hot(predicted_indices,
num_centroids).float().reshape(-1, num_centroids)
avg_probs = encodings.mean(0)
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
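# Worked example (added; not part of the original commit): if 8 indices hit 4
# centroids equally often, avg_probs is 0.25 everywhere, so perplexity is ~4.0
# and cluster_use is 4; a collapsed codebook that maps everything to a single
# centroid would instead give perplexity ~1.0 and cluster_use 1.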
class DiagonalGaussianRegularizer(AbstractRegularizer):
def __init__(self, sample: bool = True):
super().__init__()
self.sample = sample
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
log = dict()
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
else:
z = posterior.mode()
kl_loss = posterior.kl()
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
log['kl_loss'] = kl_loss
return z, log
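# Hedged mini-test (added; not part of the original commit): runs the
# regularizer on a random tensor whose channel dim holds stacked mean and
# log-variance (2 * 16 channels here) and prints the sampled latent shape plus
# the batch-averaged KL term.
if __name__ == '__main__':
    _moments = torch.randn(2, 32, 8, 8)  # (B, 2 * z_channels, H, W)
    _reg = DiagonalGaussianRegularizer(sample=True)
    _z, _log = _reg(_moments)
    print(tuple(_z.shape), float(_log['kl_loss']))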