Commit 86bf0f82 authored by 0x3f3f3f3fun's avatar 0x3f3f3f3fun
Browse files

(1) fix a bug (#26). (2) upload inference code of latent image guidance. (3)...

(1) fix a bug (#26). (2) upload inference code of latent image guidance. (3) release real47 testset!
parent 30355a12
......@@ -173,7 +173,8 @@ class CrossAttention(nn.Module):
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION =="fp32":
# with torch.autocast(enabled=False, device_type = 'cuda'):
with torch.autocast(enabled=False, device_type=str(x.device)):
with torch.autocast(enabled=False, device_type=x.device):
# with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
......
from typing import overload
from typing import overload, Optional
import torch
from torch.nn import functional as F
class Guidance:
def __init__(
    self,
    scale: float,
    t_start: int,
    t_stop: int,
    space: str,
    repeat: int,
) -> None:
    """
    Initialize latent image guidance.

    Args:
        scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
            the closer the final result will be to the output of the first stage model.
        t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling
            process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`.
        space (str): The data space for computing loss function (rgb or latent).
        repeat (int): Repeat gradient descent for `repeat` times.

    Our latent image guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
    Thanks for their work!
    """
    self.scale = scale
    self.t_start = t_start
    self.t_stop = t_stop
    # Guidance target is not known at construction time; it is supplied
    # later through `load_target` before sampling starts.
    self.target = None
    self.space = space
    self.repeat = repeat
def load_target(self, target: torch.Tensor) -> None:
    """Store the guidance target read later by the guidance step.

    NOTE(review): the tensor is stored as-is; the caller is responsible
    for moving it to the right device/dtype before sampling.
    """
    self.target = target
def __call__(self, target_x0, pred_x0, t):
def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Optional[torch.Tensor]:
if self.t_stop < t and t < self.t_start:
# print("sampling with classifier guidance")
# avoid propagating gradient out of this scope
......@@ -29,30 +48,31 @@ class Guidance:
return None
@overload
def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
    # NOTE(review): `typing.overload` is used here as an interface marker only;
    # concrete subclasses (e.g. MSEGuidance) supply the real implementation.
    ...
class MSEGuidance(Guidance):
    """Latent image guidance driven by a plain MSE loss between the
    predicted x0 and the loaded target."""

    def __init__(
        self,
        scale: float,
        t_start: int,
        t_stop: int,
        space: str,
        repeat: int,
    ) -> None:
        """See `Guidance.__init__` for the meaning of each argument."""
        super().__init__(
            scale, t_start, t_stop, space, repeat
        )

    @torch.enable_grad()
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor) -> torch.Tensor:
        """Return the negative gradient of the MSE loss w.r.t. `pred_x0`.

        Args:
            target_x0: Guidance target, same shape as `pred_x0`.
            pred_x0: Model prediction of x0; gradient is taken w.r.t. it.

        Returns:
            A tensor with the same shape as `pred_x0`. The caller scales it
            (by `self.scale`) and adds it to steer the sampling trajectory.
        """
        # inputs: [-1, 1], nchw, rgb
        pred_x0.requires_grad_(True)
        # Per-sample mean over C/H/W, then summed over the batch so each
        # sample contributes an independent gradient.
        loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        print(f"loss = {loss.item()}")
        # Negative gradient: descending the loss pulls pred_x0 towards target_x0.
        return -torch.autograd.grad(loss, pred_x0)[0]
......@@ -258,6 +258,7 @@ class SpacedSampler:
# ----------------- compute gradient for x0 in latent space ----------------- #
target, pred = None, None
if cond_fn.space == "latent":
# This is what we actually use.
target = self.model.get_first_stage_encoding(
self.model.encode_first_stage(cond_fn.target.to(device))
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment