Commit 30af93f2 authored by chenpangpang

feat: initial GPU commit

parent 68e98ab8
Pipeline #2159 canceled
import math
import numpy as np
import torch
from einops import repeat
def timestep_embedding(time_steps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param time_steps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=time_steps.device)
args = time_steps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(time_steps, "b -> b d", d=dim)
return embedding
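
def _demo_timestep_embedding():
    # Illustrative usage sketch, not part of the original commit; the batch
    # size and embedding dimension below are arbitrary placeholder values.
    t = torch.randint(0, 1000, (8,))
    emb = timestep_embedding(t, dim=320)
    assert emb.shape == (8, 320)
    return emb
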
def make_beta_schedule(
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
elif schedule == "cosine":
time_steps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = time_steps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == "sqrt":
betas = (
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
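
def _demo_make_beta_schedule():
    # Illustrative usage sketch, not part of the original commit: builds a
    # 1000-step linear schedule and sanity-checks its range and monotonicity.
    betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)
    assert betas.shape == (1000,)
    assert 0.0 < betas[0] < betas[-1] <= 2e-2 + 1e-12
    return betas
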
def make_ddim_time_steps(
ddim_discr_method, num_ddim_time_steps, num_ddpm_time_steps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_time_steps // num_ddim_time_steps
ddim_time_steps = np.asarray(list(range(0, num_ddpm_time_steps, c)))
steps_out = ddim_time_steps + 1
elif ddim_discr_method == "quad":
ddim_time_steps = (
(np.linspace(0, np.sqrt(num_ddpm_time_steps * 0.8), num_ddim_time_steps))
** 2
).astype(int)
steps_out = ddim_time_steps + 1
elif ddim_discr_method == "uniform_trailing":
c = num_ddpm_time_steps / num_ddim_time_steps
ddim_time_steps = np.flip(
np.round(np.arange(num_ddpm_time_steps, 0, -c))
).astype(np.int64)
steps_out = ddim_time_steps - 1
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_time_steps.shape[0] == num_ddim_time_steps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
if verbose:
print(f"Selected time_steps for ddim sampler: {steps_out}")
return steps_out
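
def _demo_make_ddim_time_steps():
    # Illustrative usage sketch, not part of the original commit: selects 50
    # DDIM steps out of a 1000-step DDPM schedule with uniform discretization.
    steps = make_ddim_time_steps("uniform", 50, 1000, verbose=False)
    assert steps.shape == (50,)
    assert steps[0] == 1 and steps[-1] == 981
    return steps
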
def make_ddim_sampling_parameters(alphacums, ddim_time_steps, eta, verbose=True):
# select alphas for computing the variance schedule
# print(f'ddim_time_steps={ddim_time_steps}, len_alphacums={len(alphacums)}')
alphas = alphacums[ddim_time_steps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_time_steps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
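
def _demo_make_ddim_sampling_parameters():
    # Illustrative usage sketch, not part of the original commit: derives the
    # DDIM alpha/sigma schedule for 50 steps from a linear beta schedule.
    betas = make_beta_schedule("linear", 1000)
    alphacums = np.cumprod(1.0 - betas, axis=0)
    ddim_steps = make_ddim_time_steps("uniform", 50, 1000, verbose=False)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
        alphacums, ddim_steps, eta=0.0, verbose=False
    )
    # eta = 0 corresponds to the deterministic DDIM sampler (all sigmas zero).
    assert np.allclose(sigmas, 0.0)
    return sigmas, alphas, alphas_prev
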
def betas_for_alpha_bar(num_diffusion_time_steps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_time_steps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_time_steps):
t1 = i / num_diffusion_time_steps
t2 = (i + 1) / num_diffusion_time_steps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
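
def _demo_betas_for_alpha_bar():
    # Illustrative usage sketch, not part of the original commit: the cosine
    # alpha_bar lambda used here follows the squared-cosine form popularized
    # for diffusion schedules; the constants are placeholder choices.
    betas = betas_for_alpha_bar(
        1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
    )
    assert betas.shape == (1000,)
    assert np.all(betas > 0) and np.all(betas <= 0.999)
    return betas
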
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`numpy.ndarray`):
the betas that the scheduler is being initialized with.
Returns:
`numpy.ndarray`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_bar_sqrt = np.sqrt(alphas_cumprod)
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = np.concatenate([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
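
def _demo_rescale_zero_terminal_snr():
    # Illustrative usage sketch, not part of the original commit: after
    # rescaling, the terminal cumulative alpha is zero, i.e. SNR(T) == 0.
    betas = make_beta_schedule("linear", 1000)
    betas_rescaled = rescale_zero_terminal_snr(betas)
    alphas_cumprod = np.cumprod(1.0 - betas_rescaled, axis=0)
    assert abs(alphas_cumprod[-1]) < 1e-8
    return betas_rescaled
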
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
factor = guidance_rescale * (std_text / std_cfg) + (1 - guidance_rescale)
return noise_cfg * factor
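
def _demo_rescale_noise_cfg():
    # Illustrative usage sketch, not part of the original commit: applies
    # guidance rescaling to a dummy CFG-combined noise prediction; the shapes
    # and guidance scale are arbitrary placeholders.
    noise_pred_text = torch.randn(2, 4, 32, 32)
    noise_pred_uncond = torch.randn(2, 4, 32, 32)
    noise_cfg = noise_pred_uncond + 7.5 * (noise_pred_text - noise_pred_uncond)
    rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
    assert rescaled.shape == noise_cfg.shape
    return rescaled
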
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from core.common import gradient_checkpoint
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False
print(f"XFORMERS_IS_AVAILBLE: {XFORMERS_IS_AVAILBLE}")
def get_group_norm_layer(in_channels):
if in_channels < 32:
if in_channels % 2 == 0:
num_groups = in_channels // 2
else:
num_groups = in_channels
else:
num_groups = 32
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
if dim_out is None:
dim_out = dim
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
class SpatialTemporalAttention(nn.Module):
"""Uses xformers to implement efficient epipolar masking for cross-attention between views."""
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
if context_dim is None:
context_dim = query_dim
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
self.attention_op = None
def forward(self, x, context=None, enhance_multi_view_correspondence=False):
q = self.to_q(x)
if context is None:
context = x
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], self.heads, self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b * self.heads, t.shape[1], self.dim_head)
.contiguous(),
(q, k, v),
)
if enhance_multi_view_correspondence:
with torch.no_grad():
normalized_x = torch.nn.functional.normalize(x.detach(), p=2, dim=-1)
cosine_sim_map = torch.bmm(normalized_x, normalized_x.transpose(-1, -2))
attn_bias = torch.where(cosine_sim_map > 0.0, 0.0, -1e9).to(
dtype=q.dtype
)
attn_bias = rearrange(
attn_bias.unsqueeze(1).expand(-1, self.heads, -1, -1),
"b h d1 d2 -> (b h) d1 d2",
).detach()
else:
attn_bias = None
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=attn_bias, op=self.attention_op
)
out = (
out.unsqueeze(0)
.reshape(b, self.heads, out.shape[1], self.dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], self.heads * self.dim_head)
)
del q, k, v, attn_bias
return self.to_out(out)
class MultiViewSelfAttentionTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
gated_ff=True,
use_checkpoint=True,
full_spatial_temporal_attention=False,
enhance_multi_view_correspondence=False,
):
super().__init__()
attn_cls = SpatialTemporalAttention
# self.self_attention_only = self_attention_only
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=None,
        )  # self-attention (context_dim is None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
if enhance_multi_view_correspondence:
            # Zero initialization when MVCorr is enabled.
zero_module_fn = zero_module
else:
def zero_module_fn(x):
return x
self.attn2 = zero_module_fn(
attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=None,
)
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.use_checkpoint = use_checkpoint
self.full_spatial_temporal_attention = full_spatial_temporal_attention
self.enhance_multi_view_correspondence = enhance_multi_view_correspondence
def forward(self, x, time_steps=None):
return gradient_checkpoint(
self.many_stream_forward, (x, time_steps), None, flag=self.use_checkpoint
)
def many_stream_forward(self, x, time_steps=None):
n, v, hw = x.shape[:3]
x = rearrange(x, "n v hw c -> n (v hw) c")
x = (
self.attn1(
self.norm1(x), context=None, enhance_multi_view_correspondence=False
)
+ x
)
if not self.full_spatial_temporal_attention:
x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
x = rearrange(x, "n v hw c -> (n v) hw c")
x = (
self.attn2(
self.norm2(x),
context=None,
enhance_multi_view_correspondence=self.enhance_multi_view_correspondence
and hw <= 256,
)
+ x
)
x = self.ff(self.norm3(x)) + x
if self.full_spatial_temporal_attention:
x = rearrange(x, "n (v hw) c -> n v hw c", v=v)
else:
x = rearrange(x, "(n v) hw c -> n v hw c", v=v)
return x
class MultiViewSelfAttentionTransformer(nn.Module):
"""Spatial Transformer block with post init to add cross attn."""
def __init__(
self,
in_channels,
n_heads,
d_head,
num_views,
depth=1,
dropout=0.0,
use_linear=True,
use_checkpoint=True,
zero_out_initialization=True,
full_spatial_temporal_attention=False,
enhance_multi_view_correspondence=False,
):
super().__init__()
self.num_views = num_views
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = get_group_norm_layer(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
MultiViewSelfAttentionTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
use_checkpoint=use_checkpoint,
full_spatial_temporal_attention=full_spatial_temporal_attention,
enhance_multi_view_correspondence=enhance_multi_view_correspondence,
)
for d in range(depth)
]
)
self.zero_out_initialization = zero_out_initialization
if zero_out_initialization:
_zero_func = zero_module
else:
def _zero_func(x):
return x
if not use_linear:
self.proj_out = _zero_func(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
else:
self.proj_out = _zero_func(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, time_steps=None):
# x: bt c h w
_, c, h, w = x.shape
n_views = self.num_views
x_in = x
x = self.norm(x)
x = rearrange(x, "(n v) c h w -> n v (h w) c", v=n_views)
if self.use_linear:
x = rearrange(x, "n v x c -> (n v) x c")
x = self.proj_in(x)
x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
for i, block in enumerate(self.transformer_blocks):
x = block(x, time_steps=time_steps)
if self.use_linear:
x = rearrange(x, "n v x c -> (n v) x c")
x = self.proj_out(x)
x = rearrange(x, "(n v) x c -> n v x c", v=n_views)
x = rearrange(x, "n v (h w) c -> (n v) c h w", h=h, w=w).contiguous()
return x + x_in
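
def _demo_multi_view_transformer():
    # Illustrative usage sketch, not part of the original commit: runs the
    # multi-view transformer on a dummy (batch * views, C, H, W) feature map.
    # It needs xformers and a CUDA device, so it is skipped otherwise; all
    # shapes below are arbitrary placeholders.
    if not (XFORMERS_IS_AVAILBLE and torch.cuda.is_available()):
        return None
    model = MultiViewSelfAttentionTransformer(
        in_channels=64,
        n_heads=4,
        d_head=16,
        num_views=4,
        use_checkpoint=False,  # forward-only sketch, no gradient checkpointing
    ).cuda()
    x = torch.randn(2 * 4, 64, 8, 8, device="cuda")  # 2 samples x 4 views
    out = model(x)
    assert out.shape == x.shape
    return out
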
import torch
def process_inference_batch(cfg_scale, batch, model, with_uncondition_extra=False):
for k in batch.keys():
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(model.device, dtype=model.dtype)
z, cond, x_rec = model.get_batch_input(
batch,
random_drop_training_conditions=False,
return_reconstructed_target_images=True,
)
# batch_size = x_rec.shape[0]
# Get unconditioned embedding for classifier-free guidance sampling
if cfg_scale != 1.0:
uc = model.get_unconditional_dict_for_sampling(batch, cond, x_rec)
else:
uc = None
if with_uncondition_extra:
uc_extra = model.get_unconditional_dict_for_sampling(
batch, cond, x_rec, is_extra=True
)
return cond, uc, uc_extra, x_rec
else:
return cond, uc, x_rec
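
class _MockDiffusionModel:
    # Illustrative mock, not part of the original commit: a minimal object
    # exposing the interface process_inference_batch relies on. The tensor
    # shapes and the "c_crossattn" key are arbitrary placeholders.
    device = torch.device("cpu")
    dtype = torch.float32

    def get_batch_input(
        self,
        batch,
        random_drop_training_conditions,
        return_reconstructed_target_images,
    ):
        z = torch.zeros(1, 4, 8, 8)
        cond = {"c_crossattn": [torch.zeros(1, 77, 768)]}
        x_rec = torch.zeros(1, 3, 64, 64)
        return z, cond, x_rec

    def get_unconditional_dict_for_sampling(self, batch, cond, x_rec, is_extra=False):
        return {k: [torch.zeros_like(t) for t in v] for k, v in cond.items()}


def _demo_process_inference_batch():
    # Illustrative sketch: with cfg_scale != 1.0 an unconditional dict is
    # returned for classifier-free guidance sampling.
    batch = {"image": torch.zeros(1, 3, 64, 64)}
    cond, uc, x_rec = process_inference_batch(7.5, batch, _MockDiffusionModel())
    assert uc is not None
    return cond, uc, x_rec
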
pytorch_lightning
deepspeed
taming-transformers
scipy
einops
kornia
open_clip_torch
openai-clip
xformers
timm
av
gradio
\ No newline at end of file
FLAG_RUN_DEBUG = False
PATH_DIR_DEBUG = "./debug/"