Commit 33f3023e authored by 1SAA's avatar 1SAA
Browse files

[hotfix] fix implement error in diffusers

parent 48d33b1b
...@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool: ...@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
return False return False
def _has_grad_tensor(obj) -> bool:
if isinstance(obj, tuple) or isinstance(obj, list):
for x in obj:
if _has_grad_tensor(x):
return True
return False
elif isinstance(obj, dict):
for x in obj.values():
if _has_grad_tensor(x):
return True
return False
else:
return _is_grad_tensor(obj)
def _get_grad_args(*args): def _get_grad_args(*args):
# if there is no grad tensors, do nothing
if not _has_grad_tensor(args):
return args, None
# returns the identical args if there is a grad tensor # returns the identical args if there is a grad tensor
for obj in args: for obj in args:
if _is_grad_tensor(obj): if _is_grad_tensor(obj):
......
...@@ -7,27 +7,22 @@ ...@@ -7,27 +7,22 @@
# #
# thanks! # thanks!
import os
import math import math
import os
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
from einops import repeat from einops import repeat
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear": if schedule == "linear":
betas = ( betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine": elif schedule == "cosine":
timesteps = ( timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2) alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0] alphas = alphas / alphas[0]
...@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, ...@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
elif schedule == "sqrt_linear": elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt": elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
else: else:
raise ValueError(f"schedule '{schedule}' unknown.") raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy() return betas.numpy()
...@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep ...@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
c = num_ddpm_timesteps // num_ddim_timesteps c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad': elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
else: else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
...@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag): ...@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag):
:param flag: if False, disable gradient checkpointing. :param flag: if False, disable gradient checkpointing.
""" """
if flag: if flag:
args = tuple(inputs) + tuple(params) from torch.utils.checkpoint import checkpoint as torch_checkpoint
return CheckpointFunction.apply(func, len(inputs), *args) return torch_checkpoint(func, *inputs)
# args = tuple(inputs) + tuple(params)
# return CheckpointFunction.apply(func, len(inputs), *args)
else: else:
return func(*inputs) return func(*inputs)
class CheckpointFunction(torch.autograd.Function): class CheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, run_function, length, *args): def forward(ctx, run_function, length, *args):
ctx.run_function = run_function ctx.run_function = run_function
ctx.input_tensors = list(args[:length]) ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:]) ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), ctx.gpu_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(), "dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled()} "cache_enabled": torch.is_autocast_cache_enabled()
}
with torch.no_grad(): with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors) output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors return output_tensors
...@@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): ...@@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
""" """
if not repeat_only: if not repeat_only:
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half half).to(device=timesteps.device)
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:
...@@ -211,14 +210,17 @@ def normalization(channels): ...@@ -211,14 +210,17 @@ def normalization(channels):
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module): class SiLU(nn.Module):
def forward(self, x): def forward(self, x):
return x * torch.sigmoid(x) return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm): class GroupNorm32(nn.GroupNorm):
def forward(self, x): def forward(self, x):
return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs): def conv_nd(dims, *args, **kwargs):
""" """
Create a 1D, 2D, or 3D convolution module. Create a 1D, 2D, or 3D convolution module.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment