Commit 0e56f303 authored by mashun's avatar mashun
Browse files

pyramid-flow

parents
Pipeline #2007 canceled with stages
import functools
import torch.nn as nn
from einops import rearrange
import torch
def weights_init(m):
    """Initialise convolution and batch-norm layers in place.

    Meant to be handed to ``nn.Module.apply``. Modules are matched on their
    class name:

    * name contains ``"Conv"``      -> weight ~ N(0.0, 0.02), bias = 0
    * name contains ``"BatchNorm"`` -> weight ~ N(1.0, 0.02), bias = 0

    Every other module is left untouched. A missing parameter (for example a
    convolution created with ``bias=False`` or a batch-norm created with
    ``affine=False``) is skipped instead of raising ``AttributeError``.

    Parameters:
        m (nn.Module) -- the module to initialise
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        weight_mean = 0.0
    elif "BatchNorm" in classname:
        weight_mean = 1.0
    else:
        return

    # Both parameters are optional on the matched layer types, so guard them.
    weight = getattr(m, "weight", None)
    if weight is not None:
        nn.init.normal_(weight.data, weight_mean, 0.02)

    bias = getattr(m, "bias", None)
    if bias is not None:
        nn.init.constant_(bias.data, 0)
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py

    The network is a stack of 4x4 convolutions: ``n_layers`` stride-2 layers,
    one stride-1 layer and a final stride-1 layer that produces a single
    channel prediction map. Normalisation is fixed to ``nn.InstanceNorm2d``.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=4):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int) -- the number of channels in input images
            ndf (int)      -- the number of filters produced by the first conv
                              layer; later layers use ``ndf * min(2 ** n, 8)``
            n_layers (int) -- the number of stride-2 conv layers in the
                              discriminator (including the first one)
        """
        super().__init__()
        # norm_layer = nn.BatchNorm2d
        norm_layer = nn.InstanceNorm2d
        # Unwrap functools.partial so the comparison sees the real layer class.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        # no need to use bias as BatchNorm2d has affine parameters
        use_bias = norm_cls != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more block at stride 1 before the prediction head.
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # output 1 channel prediction map
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
class NLayerDiscriminator3D(nn.Module):
    """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.

    Same layout as the 2D variant but built from ``nn.Conv3d`` layers. The
    first convolution strides by 2 on every axis, the intermediate ones use
    stride ``(1, 2, 2)`` (first of the three convolved axes is not strided).
    Normalisation is fixed to ``nn.InstanceNorm3d``.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """
        Construct a 3D PatchGAN discriminator

        Parameters:
            input_nc (int)     -- the number of channels in input volumes
            ndf (int)          -- the number of filters produced by the first
                                  conv layer; later layers use
                                  ``ndf * min(2 ** n, 8)``
            n_layers (int)     -- the number of strided conv layers in the
                                  discriminator (including the first one)
            use_actnorm (bool) -- accepted for interface compatibility only;
                                  the value is currently ignored and
                                  InstanceNorm3d is always used
        """
        super().__init__()
        norm_layer = nn.InstanceNorm3d
        # Unwrap functools.partial so the comparison sees the real layer class.
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        # A conv bias is redundant only in front of BatchNorm3d.
        use_bias = norm_cls != nn.BatchNorm3d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            sequence += [
                nn.Conv3d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=(kw, kw, kw),
                    stride=(1, 2, 2),
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more block at stride 1 before the prediction head.
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        sequence += [
            nn.Conv3d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=(kw, kw, kw),
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # output 1 channel prediction map
        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
\ No newline at end of file
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import SpatialNorm
from .modeling_block import (
UNetMidBlock2D,
CausalUNetMidBlock2D,
get_down_block,
get_up_block,
get_input_layer,
get_output_layer,
)
from .modeling_resnet import (
Downsample2D,
Upsample2D,
TemporalDownsample2x,
TemporalUpsample2x,
)
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
@dataclass
class DecoderOutput(BaseOutput):
    """
    Container for the result of a decoding call.

    Args:
        sample (`torch.FloatTensor`):
            The decoded sample produced by the last layer of the model.
            NOTE(review): the upstream docstring gave the shape as
            `(batch_size, num_channels, height, width)`; the decoder in this
            file is built from causal 3D convolutions, so an extra frame axis
            is presumably present -- confirm against the caller.
    """

    sample: torch.FloatTensor
class CausalVaeEncoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    The pipeline is: `conv_in` -> down blocks -> mid block -> group norm -> SiLU -> `conv_out`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels (doubled when `double_z` is set).
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlockCausal3D",)`):
            The types of down blocks to use; each name is resolved by `get_down_block`.
        spatial_down_sample (`Tuple[bool, ...]`, *optional*, defaults to `(True,)`):
            Per-block flag forwarded to `get_down_block` as `add_spatial_downsample`.
        temporal_down_sample (`Tuple[bool, ...]`, *optional*, defaults to `(False,)`):
            Per-block flag forwarded to `get_down_block` as `add_temporal_downsample`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(2,)`):
            The number of layers for each block (one entry per down block).
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function used inside the blocks. The output activation is always `nn.SiLU`.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels of `conv_out`.
        block_dropout (`Tuple[float, ...]`, *optional*, defaults to `(0.0,)`):
            Dropout value for each down block; the last entry is also used for the mid block.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        spatial_down_sample: Tuple[bool, ...] = (True,),
        temporal_down_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        block_dropout: Tuple[int, ...] = (0.0,),
        mid_block_add_attention=True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
        )

        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                add_spatial_downsample=spatial_down_sample[i],
                add_temporal_downsample=temporal_down_sample[i],
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                dropout=block_dropout[i],
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # out
        self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3, stride=1)

        # Enabled externally; only honoured while the module is in training mode.
        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class.

        Args:
            sample: input tensor fed to `conv_in`.
            is_init_image: flag forwarded unchanged to every causal layer/block.
            temporal_chunk: flag forwarded unchanged to every causal layer/block.

        Returns:
            The output of `conv_out` (latent parameters when `double_z` is set).
        """
        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.training and self.gradient_checkpointing:
            # Imported explicitly so the submodule is guaranteed to be loaded.
            from torch.utils.checkpoint import checkpoint

            def create_custom_forward(module):
                # The flags are bound by keyword -- exactly as in the
                # non-checkpointed path -- so the call does not depend on the
                # positional parameter order of the blocks' forward(), which
                # is defined in another module.
                def custom_forward(hidden_states):
                    return module(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

                return custom_forward

            # `use_reentrant` only exists on torch >= 1.11.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # down
            for down_block in self.down_blocks:
                sample = checkpoint(create_custom_forward(down_block), sample, **ckpt_kwargs)

            # middle
            sample = checkpoint(create_custom_forward(self.mid_block), sample, **ckpt_kwargs)
        else:
            # down
            for down_block in self.down_blocks:
                sample = down_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample
class CausalVaeDecoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    The pipeline is: `conv_in` -> mid block -> up blocks -> group norm -> SiLU -> `conv_out`.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlockCausal3D",)`):
            The types of up blocks to use; each name is resolved by `get_up_block`.
        spatial_up_sample (`Tuple[bool, ...]`, *optional*, defaults to `(True,)`):
            Per-block flag forwarded to `get_up_block` as `add_spatial_upsample`.
        temporal_up_sample (`Tuple[bool, ...]`, *optional*, defaults to `(False,)`):
            Per-block flag forwarded to `get_up_block` as `add_temporal_upsample`.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block, in encoder order (they are consumed in reverse).
        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(2,)`):
            The number of layers for each block (one entry per up block).
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function used inside the blocks. The output activation is always `nn.SiLU`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded to the mid block as `add_attention`.
        interpolate (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to `get_up_block`.
        block_dropout (`Tuple[float, ...]`, *optional*, defaults to `(0.0,)`):
            Dropout value for each up block; the last entry is also used for the mid block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        spatial_up_sample: Tuple[bool, ...] = (True,),
        temporal_up_sample: Tuple[bool, ...] = (False,),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Tuple[int, ...] = (2,),
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        interpolate: bool = True,
        block_dropout: Tuple[int, ...] = (0.0,),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
        )

        # Registered before the mid block so the sub-module order is unchanged.
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = CausalUNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            dropout=block_dropout[-1],
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_spatial_upsample=spatial_up_sample[i],
                add_temporal_upsample=temporal_up_sample[i],
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
                resnet_time_scale_shift="default",
                interpolate=interpolate,
                dropout=block_dropout[i],
            )
            self.up_blocks.append(up_block)

        # out
        self.conv_norm_out = CausalGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, stride=1)

        # Enabled externally; only honoured while the module is in training mode.
        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class.

        Args:
            sample: latent tensor fed to `conv_in`.
            is_init_image: flag forwarded unchanged to every causal layer/block.
            temporal_chunk: flag forwarded unchanged to every causal layer/block.

        Returns:
            The output of `conv_out`.
        """
        sample = self.conv_in(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # The mid-block output is cast to the dtype of the up blocks' parameters.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        if self.training and self.gradient_checkpointing:
            # Imported explicitly so the submodule is guaranteed to be loaded.
            from torch.utils.checkpoint import checkpoint

            def create_custom_forward(module):
                # Bind the flags inside the closure. Passing them as keyword
                # arguments to `checkpoint` fails: the wrapper previously took
                # only `*inputs`, and the reentrant implementation rejects
                # extra keyword arguments altogether.
                def custom_forward(hidden_states):
                    return module(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

                return custom_forward

            # `use_reentrant` only exists on torch >= 1.11.
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

            # middle
            sample = checkpoint(create_custom_forward(self.mid_block), sample, **ckpt_kwargs)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = checkpoint(create_custom_forward(up_block), sample, **ckpt_kwargs)
        else:
            # middle
            sample = self.mid_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        return sample
class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance, parameterised by mean and log-variance.

    `parameters` is split into two equal halves along dim 1: the first half is
    the mean, the second the log-variance (clamped to [-30, 20]). With
    `deterministic=True` the variance and standard deviation are set to zero,
    so sampling returns the mean and `kl` / `nll` return a zero tensor.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp so that exp() below can neither overflow nor underflow badly.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        """Return a one-element zero tensor on the parameters' device and dtype."""
        # A plain `torch.Tensor([0.0])` is always a CPU float32 tensor, which
        # cannot be combined with losses living on another device.
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw `mean + std * eps` with `eps` standard normal noise."""
        # make sure sample is on the same device as the parameters and has same dtype
        noise = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """KL divergence to `other`, or to the standard normal when `other` is None.

        The element-wise terms are summed over dims 2, 3 and 4, so the mean is
        expected to have five dimensions. Returns a one-element zero tensor in
        deterministic mode.
        """
        if self.deterministic:
            return self._zero()

        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[2, 3, 4],
            )

        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[2, 3, 4],
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Negative log-likelihood of `sample`, summed over `dims`.

        Returns a one-element zero tensor in deterministic mode.
        """
        if self.deterministic:
            return self._zero()

        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
\ No newline at end of file
import os
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from .modeling_lpips import LPIPS
from .modeling_discriminator import NLayerDiscriminator, NLayerDiscriminator3D, weights_init
class AdaptiveLossWeight:
    """Per-timestep loss weighting based on a running average of the loss.

    The timestep range is cut into `buckets` bins. Each bin keeps an
    exponential moving average of the losses seen in it (starting at 1), and
    the weight for a timestep is the reciprocal of its bin's average, clamped
    to `weight_range`.

    Args:
        timestep_range: `(low, high)` bounds used to place the bin edges.
        buckets: number of bins (`buckets - 1` edges are created).
        weight_range: `(min, max)` clamp applied to the returned weights.
    """

    # Tuple defaults: `weight_range` is stored on the instance, so a list
    # default would be shared (and mutable) across all instances.
    def __init__(self, timestep_range=(0, 1), buckets=300, weight_range=(1e-7, 1e7)):
        self.bucket_ranges = torch.linspace(timestep_range[0], timestep_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, timestep):
        """Return the clamped inverse running loss for each entry of `timestep`."""
        device = timestep.device
        indices = torch.searchsorted(self.bucket_ranges.to(device), timestep)
        inverse_loss = 1 / self.bucket_losses.to(device)[indices]
        return inverse_loss.clamp(*self.weight_range)

    def update_buckets(self, timestep, loss, beta=0.99):
        """Blend `loss` into the running average of each timestep's bin.

        The statistics are kept on the CPU; `loss` is detached and moved there.
        """
        indices = torch.searchsorted(self.bucket_ranges.to(timestep.device), timestep).cpu()
        self.bucket_losses[indices] = self.bucket_losses[indices] * beta + loss.detach().cpu() * (1 - beta)
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss.

    Averages `relu(1 - logits_real)` and `relu(1 + logits_fake)` over all
    elements and returns half of their sum as a scalar tensor.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating (softplus) discriminator loss.

    Averages `softplus(-logits_real)` and `softplus(logits_fake)` over all
    elements and returns half of their sum as a scalar tensor.
    """
    softplus = torch.nn.functional.softplus
    real_term = torch.mean(softplus(-logits_real))
    fake_term = torch.mean(softplus(logits_fake))
    return 0.5 * (real_term + fake_term)
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `value` while `global_step` is below `threshold`, else `weight`."""
    return value if global_step < threshold else weight
class LPIPSWithDiscriminator(nn.Module):
    """Reconstruction + LPIPS + KL loss with an optional adversarial term.

    `forward` is called twice per step: with `optimizer_idx == 0` it returns
    the autoencoder (generator) loss, with `optimizer_idx == 1` the
    discriminator loss.

    Args:
        disc_start: global step from which the adversarial term is active.
        logvar_init: initial value of the learnable log-variance scalar.
        kl_weight: multiplier of the KL term.
        pixelloss_weight: multiplier of the pixel (MSE) reconstruction term.
        perceptual_weight: multiplier of the LPIPS term; values <= 0 skip the
            perceptual forward pass.
        lpips_ckpt: path of the LPIPS checkpoint handed to `LPIPS`.
        disc_num_layers: `n_layers` of the discriminator.
        disc_in_channels: `input_nc` of the discriminator.
        disc_factor: global multiplier of the adversarial term.
        disc_weight: multiplier applied to the adaptive discriminator weight.
        disc_loss: `"hinge"` or `"vanilla"`.
        add_discriminator: build a discriminator; when False the adversarial
            term is never added to the generator loss.
        using_3d_discriminator: use `NLayerDiscriminator3D` on `b c t h w`
            tensors instead of the 2D discriminator on per-frame tensors.

    Raises:
        ValueError: if `disc_loss` is not one of the supported names.
    """

    def __init__(
        self,
        disc_start,
        logvar_init=0.0,
        kl_weight=1.0,
        pixelloss_weight=1.0,
        perceptual_weight=1.0,
        lpips_ckpt='/home/jinyang06/models/vae/video_vae_baseline/vgg_lpips.pth',
        # --- Discriminator Loss ---
        disc_num_layers=4,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=0.5,
        disc_loss="hinge",
        add_discriminator=True,
        using_3d_discriminator=False,
    ):
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under -O.
        if disc_loss not in ("hinge", "vanilla"):
            raise ValueError(f"disc_loss must be 'hinge' or 'vanilla', got {disc_loss!r}")

        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS(lpips_ckpt_path=lpips_ckpt).eval()
        self.perceptual_weight = perceptual_weight
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        if add_discriminator:
            disc_cls = NLayerDiscriminator3D if using_3d_discriminator else NLayerDiscriminator
            self.discriminator = disc_cls(
                input_nc=disc_in_channels, n_layers=disc_num_layers,
            ).apply(weights_init)
        else:
            self.discriminator = None

        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.using_3d_discriminator = using_3d_discriminator

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the adversarial loss against the reconstruction loss.

        Returns `||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4)` with the
        gradients taken w.r.t. `last_layer`, clamped to [0, 1e4], detached and
        scaled by `self.discriminator_weight`.
        """
        # NOTE(review): `self.last_layer` is not defined anywhere in this
        # class, so callers are expected to always pass `last_layer`.
        layer = last_layer if last_layer is not None else self.last_layer[0]

        nll_grads = torch.autograd.grad(nll_loss, layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        return d_weight * self.discriminator_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        split="train",
        last_layer=None,
    ):
        """Compute the generator (`optimizer_idx == 0`) or discriminator (`== 1`) loss.

        Args:
            inputs: ground-truth tensor laid out as `b c t h w`.
            reconstructions: model output with the same layout.
            posteriors: object exposing `kl()` (only used for `optimizer_idx == 0`).
            optimizer_idx: 0 for the autoencoder update, 1 for the discriminator update.
            global_step: current step, compared with the discriminator start step.
            split: prefix of the keys in the returned log dict.
            last_layer: parameter used to compute the adaptive discriminator weight.

        Returns:
            `(loss, log)` where `log` maps `"{split}/..."` names to detached tensors.

        Raises:
            ValueError: if `optimizer_idx` is neither 0 nor 1.
            RuntimeError: if the discriminator update is requested but no
                discriminator was built.
        """
        t = reconstructions.shape[2]
        # Fold the frame axis into the batch so 2D losses see single frames.
        inputs = rearrange(inputs, "b c t h w -> (b t) c h w").contiguous()
        reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w").contiguous()

        if optimizer_idx == 0:
            # rec_loss = torch.mean(torch.abs(inputs - reconstructions), dim=(1,2,3), keepdim=True)
            rec_loss = torch.mean(F.mse_loss(inputs, reconstructions, reduction='none'), dim=(1, 2, 3), keepdim=True)

            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(inputs, reconstructions)
                nll_loss = self.pixel_weight * rec_loss + self.perceptual_weight * p_loss
            else:
                # Keep `p_loss` defined so the loss and the log below still
                # work when the perceptual term is switched off.
                p_loss = torch.zeros_like(rec_loss)
                nll_loss = self.pixel_weight * rec_loss

            nll_loss = nll_loss / torch.exp(self.logvar) + self.logvar
            weighted_nll_loss = nll_loss
            weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
            nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

            kl_loss = posteriors.kl()
            kl_loss = torch.mean(kl_loss)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )

            # No adversarial term without a discriminator (add_discriminator=False).
            if disc_factor > 0.0 and self.discriminator is not None:
                if self.using_3d_discriminator:
                    reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)

                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    # Gradients are unavailable during evaluation; while
                    # training the original error is the useful one.
                    if self.training:
                        raise
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0)

            loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )
            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/perception_loss".format(split): p_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            if self.discriminator is None:
                raise RuntimeError("discriminator update requested but add_discriminator=False")

            if self.using_3d_discriminator:
                inputs = rearrange(inputs, '(b t) c h w -> b c t h w', t=t)
                reconstructions = rearrange(reconstructions, '(b t) c h w -> b c t h w', t=t)

            # Detached: this step must not back-propagate into the autoencoder.
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log

        raise ValueError(f"optimizer_idx must be 0 or 1, got {optimizer_idx!r}")
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
class LPIPS(nn.Module):
    """Learned perceptual metric.

    Both images are rescaled, passed through the VGG16 feature slices, and the
    squared differences of the channel-normalised activations are projected by
    one 1x1 convolution per slice, spatially averaged and summed.

    Args:
        use_dropout: put an `nn.Dropout` in front of every 1x1 projection.
        lpips_ckpt_path: path of the state dict loaded by `load_from_pretrained`.
    """

    def __init__(self, use_dropout=True, lpips_ckpt_path=None):
        super().__init__()
        self.lpips_ckpt_path = lpips_ckpt_path  # replace with your lpips path
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        # Attribute names lin0..lin4 are part of the checkpoint's key layout.
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # The metric is frozen; it is never trained together with the model.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self):
        """Load the state dict stored at `self.lpips_ckpt_path` (non-strict).

        Raises:
            ValueError: if no checkpoint path was configured.
        """
        ckpt = self.lpips_ckpt_path
        # Explicit check instead of `assert`, which is stripped under -O.
        if ckpt is None:
            raise ValueError("Please replace with your lpips path")
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    def forward(self, input, target):
        """Return the perceptual distance between `input` and `target`.

        The result keeps the reduced spatial dims (`keepdim=True`), one value
        per batch element and projection output channel.
        """
        feats_input = self.net(self.scaling_layer(input))
        feats_target = self.net(self.scaling_layer(target))
        lins = (self.lin0, self.lin1, self.lin2, self.lin3, self.lin4)

        val = None
        for lin, feat_in, feat_tgt in zip(lins, feats_input, feats_target):
            diff = (normalize_tensor(feat_in) - normalize_tensor(feat_tgt)) ** 2
            layer_score = spatial_average(lin.model(diff), keepdim=True)
            # Out-of-place sum: no per-layer result is modified in place.
            val = layer_score if val is None else val + layer_score
        return val
class ScalingLayer(nn.Module):
    """Per-channel affine rescaling `(x - shift) / scale` for 3-channel inputs.

    `shift` and `scale` are registered as buffers of shape (1, 3, 1, 1) so
    they broadcast over batch and spatial dimensions.
    """

    def __init__(self):
        super().__init__()
        shift = torch.Tensor([-.030, -.088, -.188])
        scale = torch.Tensor([.458, .448, .450])
        self.register_buffer('shift', shift.view(1, 3, 1, 1))
        self.register_buffer('scale', scale.view(1, 3, 1, 1))

    def forward(self, inp):
        """Return `inp` shifted and scaled channel-wise."""
        centred = inp - self.shift
        return centred / self.scale
class NetLinLayer(nn.Module):
    """ A single linear layer which does a 1x1 conv

    The bias-free convolution lives in `self.model`, optionally preceded by a
    default `nn.Dropout`.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        projection = nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False)
        # Dropout (when requested) must stay at index 0 so the convolution
        # keeps the same position inside the Sequential.
        stages = [nn.Dropout(), projection] if use_dropout else [projection]
        self.model = nn.Sequential(*stages)
class vgg16(torch.nn.Module):
    """torchvision VGG16 feature extractor cut into five consecutive slices.

    `forward` returns the output of every slice as a `VggOutputs` namedtuple
    with the fields relu1_2, relu2_2, relu3_3, relu4_3 and relu5_3.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super().__init__()
        features = models.vgg16(pretrained=pretrained).features

        # Slice k holds the feature layers [bounds[k], bounds[k + 1]).
        bounds = (0, 4, 9, 16, 23, 30)
        for slice_no, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:]), start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Layers keep their original index as name.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, "slice{}".format(slice_no), stage)
        self.N_slices = 5

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Run the slices in sequence and return all five intermediate outputs."""
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*activations)
def normalize_tensor(x, eps=1e-10):
    """Scale `x` to (approximately) unit L2 norm along dim 1.

    `eps` is added to the norm to avoid division by zero.
    """
    norm_factor = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (norm_factor + eps)
def spatial_average(x, keepdim=True):
    """Average `x` over dims 2 and 3, optionally keeping them as size-1 dims."""
    return torch.mean(x, dim=(2, 3), keepdim=keepdim)
if __name__ == "__main__":
    # Smoke test: print the shape of the metric for two random image batches.
    import sys

    # LPIPS cannot be built without a checkpoint, so take its path from the
    # command line instead of calling LPIPS() with the `None` default.
    if len(sys.argv) != 2:
        raise SystemExit("usage: {} <path-to-lpips-checkpoint>".format(sys.argv[0]))

    model = LPIPS(lpips_ckpt_path=sys.argv[1]).eval()
    _ = torch.manual_seed(123)
    # Random images with values in [-1, 1).
    img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
    img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
    print(model(img1, img2).shape)
\ No newline at end of file
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.normalization import AdaGroupNorm
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from .modeling_causal_conv import CausalConv3d, CausalGroupNorm
class CausalResnetBlock3D(nn.Module):
    r"""
    A Resnet block built from causal 3D convolutions.

    Layout: norm1 -> activation -> conv1 -> (+ temb) -> norm2 -> activation -> dropout -> conv2,
    added to the (optionally 1x1-projected) input and divided by `output_scale_factor`.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            Only stored as `use_conv_shortcut`; not otherwise used by this block.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            Only used by the "ada_group" and "spatial" normalizations.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, defaults to `True`):
            Accepted for interface compatibility; `self.pre_norm` is always `True`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time embedding config.
            "default" uses `CausalGroupNorm` and adds `temb` (when given) after the first convolution;
            "ada_group" uses `AdaGroupNorm` and "spatial" uses `SpatialNorm`, both conditioned on `temb`.
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 convolution for the skip-connection. If `None`, it is added whenever
            the input and output channel counts differ.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        # The `pre_norm` argument was always overridden; keep that behaviour.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        if groups_out is None:
            groups_out = groups

        if self.time_embedding_norm == "ada_group":
            self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm1 = SpatialNorm(in_channels, temb_channels)
        else:
            self.norm1 = CausalGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)

        if self.time_embedding_norm == "ada_group":
            self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
        elif self.time_embedding_norm == "spatial":
            self.norm2 = SpatialNorm(out_channels, temb_channels)
        else:
            self.norm2 = CausalGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = CausalConv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1)

        self.nonlinearity = get_activation(non_linearity)
        self.upsample = self.downsample = None
        # Project the residual only when the channel counts differ, unless
        # the caller decides explicitly.
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = CausalConv3d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor = None,
        is_init_image=True,
        temporal_chunk=False,
    ) -> torch.FloatTensor:
        """Apply the residual block.

        Args:
            input_tensor: block input.
            temb: optional embedding; fed to the norms for "ada_group"/"spatial",
                added to the hidden states for "default".
            is_init_image: flag forwarded unchanged to every causal convolution.
            temporal_chunk: flag forwarded unchanged to every causal convolution.

        Returns:
            `(shortcut(input_tensor) + hidden_states) / output_scale_factor`.
        """
        # These two norm types take the embedding as a second argument.
        conditioned_norm = self.time_embedding_norm in ("ada_group", "spatial")

        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states, temb) if conditioned_norm else self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states, temb) if conditioned_norm else self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
class ResnetBlock2D(nn.Module):
    r"""
    A Resnet block (norm -> activation -> conv, twice, with a residual connection).

    Despite the name, the convolutions are `nn.Conv3d` layers.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, default to `False`): Stored as `use_conv_shortcut`;
            not otherwise read by this class.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            Only passed to the "ada_group" / "spatial" norm layers.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, default to `True`): Accepted for compatibility; `self.pre_norm`
            is always set to `True`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Selects the norm layers:
            "ada_group" -> `AdaGroupNorm`, "spatial" -> `SpatialNorm`, anything else -> `GroupNorm`.
            With "default", `temb` (if given) is added to the hidden states after the first conv.
        output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 conv layer for the skip-connection. If `None`, one is added exactly
            when `in_channels` differs from the block's output channel count.
        conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()

        # Resolve the optional sizes up front.
        if out_channels is None:
            out_channels = in_channels
        if groups_out is None:
            groups_out = groups
        if not conv_2d_out_channels:
            conv_2d_out_channels = out_channels

        # The `pre_norm` argument is ignored: the attribute is always True.
        self.pre_norm = True
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        def make_norm(num_channels: int, num_groups: int) -> nn.Module:
            """Build the norm layer selected by `time_embedding_norm`."""
            if time_embedding_norm == "ada_group":
                return AdaGroupNorm(temb_channels, num_channels, num_groups, eps=eps)
            if time_embedding_norm == "spatial":
                return SpatialNorm(num_channels, temb_channels)
            return torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps, affine=True)

        # Sub-modules are created in a fixed order so registration order
        # (and the order of random weight initialisation) stays stable.
        self.norm1 = make_norm(in_channels, groups)
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = make_norm(out_channels, groups_out)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = nn.Conv3d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
        self.nonlinearity = get_activation(non_linearity)

        self.upsample = None
        self.downsample = None

        if use_in_shortcut is None:
            # Default: project the skip path only when channel counts differ.
            self.use_in_shortcut = self.in_channels != conv_2d_out_channels
        else:
            self.use_in_shortcut = use_in_shortcut

        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv3d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )
        else:
            self.conv_shortcut = None

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor = None,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        """Apply the block and return `(skip + hidden_states) / output_scale_factor`.

        `scale` is accepted but not used by this implementation.
        """
        # Both norm layers take `temb` as a second argument only in these two modes.
        norm_takes_temb = self.time_embedding_norm in ("ada_group", "spatial")

        skip = input_tensor

        if norm_takes_temb:
            hidden_states = self.norm1(input_tensor, temb)
        else:
            hidden_states = self.norm1(input_tensor)
        hidden_states = self.conv1(self.nonlinearity(hidden_states))

        if self.time_embedding_norm == "default" and temb is not None:
            hidden_states = hidden_states + temb

        if norm_takes_temb:
            hidden_states = self.norm2(hidden_states, temb)
        else:
            hidden_states = self.norm2(hidden_states)
        hidden_states = self.conv2(self.dropout(self.nonlinearity(hidden_states)))

        if self.conv_shortcut is not None:
            skip = self.conv_shortcut(skip)

        return (skip + hidden_states) / self.output_scale_factor
class CausalDownsample2x(nn.Module):
    """A spatial 2x downsampling layer with an optional causal convolution.

    The stride is `(1, 2, 2)`, so only the last two dimensions of the input
    are reduced; the third-to-last dimension keeps stride 1.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `True`):
            if `True`, downsample with a strided `CausalConv3d`; otherwise use
            `nn.AvgPool3d` (requires `out_channels == channels`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            stored as `self.name`.
        kernel_size (`int`, default `3`):
            kernel size of the convolution (conv path only).
        bias (`bool`, default `True`):
            whether the convolution has a bias (conv path only).
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        stride = (1, 2, 2)
        self.name = name

        if use_conv:
            conv = CausalConv3d(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
            )
        else:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            conv = nn.AvgPool3d(kernel_size=stride, stride=stride)

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        """Downsample `hidden_states`.

        `is_init_image` and `temporal_chunk` are forwarded to the causal
        convolution. They are ignored on the average-pooling path, because
        `nn.AvgPool3d.forward` only accepts the input tensor.
        """
        assert hidden_states.shape[1] == self.channels

        if self.use_conv:
            return self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)

        # Passing the causal kwargs to nn.AvgPool3d would raise TypeError.
        return self.conv(hidden_states)
class Downsample2D(nn.Module):
    """A spatial 2x downsampling layer with an optional convolution.

    The stride is `(1, 2, 2)`, so only the last two dimensions of the input
    are reduced; the third-to-last dimension keeps stride 1.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `True`):
            if `True`, downsample with a strided `nn.Conv3d`; otherwise use
            `nn.AvgPool3d` (requires `out_channels == channels`).
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `0`):
            padding for the convolution. When `0`, `forward` pads the input
            explicitly before the convolution instead.
        name (`str`, default `conv`):
            stored as `self.name`.
        kernel_size (`int`, default `3`):
            kernel size of the convolution (conv path only).
        bias (`bool`, default `True`):
            whether the convolution has a bias (conv path only).
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        padding: int = 0,
        name: str = "conv",
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = (1, 2, 2)
        self.name = name
        conv_cls = nn.Conv3d

        if use_conv:
            conv = conv_cls(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            )
        else:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            # The stride has three elements, so the pooling layer must be the
            # 3D variant; nn.AvgPool2d cannot run with a 3-element kernel size.
            conv = nn.AvgPool3d(kernel_size=stride, stride=stride)

        self.conv = conv

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Downsample `hidden_states` by 2 along its last two dimensions."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # F.pad lists pairs starting from the last dimension: pad the last
            # two dims by (0, 1) each and the third-to-last by (1, 1).
            pad = (0, 1, 0, 1, 1, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        return self.conv(hidden_states)
class TemporalDownsample2x(nn.Module):
    """A temporal 2x downsampling layer built on a strided `nn.Conv3d`.

    The stride is `(2, 1, 1)`, so only the third-to-last dimension of the
    input is reduced.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `False`):
            must be `True`; the non-convolutional variant is not implemented
            and constructing the layer without it raises `NotImplementedError`.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `0`):
            padding for the convolution. When `0`, `forward` pads the input
            explicitly before the convolution instead.
        kernel_size (`int`, default `3`):
            kernel size of the convolution.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 0,
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding

        if not use_conv:
            raise NotImplementedError("Not implemented for temporal downsample without")

        self.conv = nn.Conv3d(
            self.channels,
            self.out_channels,
            kernel_size=kernel_size,
            stride=(2, 1, 1),
            padding=padding,
            bias=bias,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Downsample `hidden_states` by 2 along its third-to-last dimension."""
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            # A single-frame input (image) is padded on both sides of the
            # third-to-last dim; longer inputs (video) only at the end.
            front = 1 if hidden_states.shape[2] == 1 else 0
            hidden_states = F.pad(hidden_states, (1, 1, 1, 1, front, 1), mode="constant", value=0)

        return self.conv(hidden_states)
class CausalTemporalDownsample2x(nn.Module):
    """A temporal 2x downsampling layer built on a strided `CausalConv3d`.

    The convolution uses stride `(2, 1, 1)`.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `False`):
            must be `True`; the non-convolutional variant is not implemented
            and constructing the layer without it raises `NotImplementedError`.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        kernel_size (`int`, default `3`):
            kernel size of the convolution.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        kernel_size=3,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv

        if not use_conv:
            raise NotImplementedError("Not implemented for temporal downsample without")

        self.conv = CausalConv3d(
            self.channels,
            self.out_channels,
            kernel_size=kernel_size,
            stride=(2, 1, 1),
            bias=bias,
        )

    def forward(self, hidden_states: torch.FloatTensor, is_init_image=True, temporal_chunk=False) -> torch.FloatTensor:
        """Apply the causal convolution, forwarding `is_init_image` and `temporal_chunk` to it."""
        assert hidden_states.shape[1] == self.channels
        return self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
class Upsample2D(nn.Module):
    """A spatial 2x upsampling layer: `nn.Conv3d` followed by a depth-to-space rearrange.

    The convolution produces `4 * out_channels` channels, which `forward`
    folds into a 2x2 block along the last two dimensions.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `False`):
            stored as `self.use_conv`; the convolution is always created.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            stored as `self.name`.
        kernel_size (`int`, optional):
            kernel size of the convolution. `None` means `3`.
        padding (`int`, default `1`):
            padding for the convolution.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `False`):
            not implemented; `True` raises `NotImplementedError`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = None,
        padding=1,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.name = name
        self.interpolate = interpolate

        if interpolate:
            raise NotImplementedError("Not implemented for spatial upsample with interpolate")

        if kernel_size is None:
            kernel_size = 3

        # 4x the channels: one output channel per position of the 2x2 block.
        self.conv = nn.Conv3d(
            self.channels,
            self.out_channels * 4,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        )
        self.conv.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise linear/conv weights with a truncated normal and LayerNorm to identity."""
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return

        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Convolve, then move groups of 4 channels into 2x2 blocks of the last two dimensions."""
        assert hidden_states.shape[1] == self.channels
        expanded = self.conv(hidden_states)
        return rearrange(expanded, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
class CausalUpsample2x(nn.Module):
    """A spatial 2x upsampling layer: `CausalConv3d` followed by a depth-to-space rearrange.

    The convolution produces `4 * out_channels` channels, which `forward`
    folds into a 2x2 block along the last two dimensions.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `False`):
            stored as `self.use_conv`; the convolution is always created.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        name (`str`, default `conv`):
            stored as `self.name`.
        kernel_size (`int`, default `3`):
            kernel size of the convolution.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `False`):
            not implemented; `True` raises `NotImplementedError`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        name: str = "conv",
        kernel_size: Optional[int] = 3,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.name = name
        self.interpolate = interpolate

        if interpolate:
            raise NotImplementedError("Not implemented for spatial upsample with interpolate")

        # 4x the channels: one output channel per position of the 2x2 block.
        self.conv = CausalConv3d(
            self.channels,
            self.out_channels * 4,
            kernel_size=kernel_size,
            stride=1,
            bias=bias,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_init_image=True, temporal_chunk=False,
    ) -> torch.FloatTensor:
        """Convolve (forwarding the causal flags), then fold channels into 2x2 spatial blocks."""
        assert hidden_states.shape[1] == self.channels
        expanded = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        return rearrange(expanded, 'b (c p1 p2) t h w -> b c t (h p1) (w p2)', p1=2, p2=2)
class TemporalUpsample2x(nn.Module):
    """A temporal 2x upsampling layer: `nn.Conv3d` followed by a depth-to-time rearrange.

    The convolution produces `2 * out_channels` channels, which `forward`
    folds into the third-to-last (`t`) dimension, doubling its length.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `True`):
            stored as `self.use_conv`; the convolution is always created.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        kernel_size (`int`, optional):
            kernel size of the convolution. `None` means `3`.
        padding (`int`, default `1`):
            padding for the convolution.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `False`):
            not implemented; `True` raises `NotImplementedError`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        kernel_size: Optional[int] = None,
        padding=1,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.interpolate = interpolate

        if interpolate:
            # This is the temporal layer; the message previously said "spatial".
            raise NotImplementedError("Not implemented for temporal upsample with interpolate")

        if kernel_size is None:
            kernel_size = 3

        # depth to space operator: 2x the channels, one per new time step.
        self.conv = nn.Conv3d(
            self.channels,
            self.out_channels * 2,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_image: bool = False,
    ) -> torch.FloatTensor:
        """Double the `t` dimension of `hidden_states`.

        When the input has a single frame and `is_image` is `True`, the first
        of the two produced frames is dropped, so the output also has one frame.
        """
        assert hidden_states.shape[1] == self.channels

        t = hidden_states.shape[2]
        hidden_states = self.conv(hidden_states)
        # NOTE(review): the pattern groups the new axis as (p t), i.e. the two
        # halves are laid out back to back rather than interleaved, whereas
        # CausalTemporalUpsample2x uses (t p) -- confirm this is intended.
        hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (p t) h w', p=2)

        if t == 1 and is_image:
            hidden_states = hidden_states[:, :, 1:]

        return hidden_states
class CausalTemporalUpsample2x(nn.Module):
    """A temporal 2x upsampling layer: `CausalConv3d` followed by a depth-to-time rearrange.

    The convolution produces `2 * out_channels` channels, which `forward`
    interleaves into the third-to-last (`t`) dimension, doubling its length.

    Parameters:
        channels (`int`):
            number of channels in the inputs.
        use_conv (`bool`, default `True`):
            stored as `self.use_conv`; the convolution is always created.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        kernel_size (`int`, default `3`):
            kernel size of the convolution.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
        interpolate (`bool`, default `False`):
            not implemented; `True` raises `NotImplementedError`.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = True,
        out_channels: Optional[int] = None,
        kernel_size: Optional[int] = 3,
        bias=True,
        interpolate=False,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.interpolate = interpolate

        if interpolate:
            # This is the temporal layer; the message previously said "spatial".
            raise NotImplementedError("Not implemented for temporal upsample with interpolate")

        # depth to space operator: 2x the channels, one per new time step.
        self.conv = CausalConv3d(
            self.channels,
            self.out_channels * 2,
            kernel_size=kernel_size,
            stride=1,
            bias=bias,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        is_init_image=True, temporal_chunk=False,
    ) -> torch.FloatTensor:
        """Double the `t` dimension of `hidden_states`.

        `is_init_image` and `temporal_chunk` are forwarded to the causal
        convolution. When `is_init_image` is `True`, the first output frame
        is dropped after upsampling.
        """
        assert hidden_states.shape[1] == self.channels

        hidden_states = self.conv(hidden_states, is_init_image=is_init_image, temporal_chunk=temporal_chunk)
        hidden_states = rearrange(hidden_states, 'b (c p) t h w -> b c (t p) h w', p=2)

        if is_init_image:
            hidden_states = hidden_states[:, :, 1:]

        return hidden_states
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment