"git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "7cb52324254f6b91dc9e71e4d45404597bc192a7"
Commit 1f6a467e authored by comfyanonymous's avatar comfyanonymous
Browse files

Update ldm dir with latest upstream stable diffusion changes.

parent 642516a3
...@@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak ...@@ -8,16 +8,17 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, mak
class DDIMSampler(object): class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs): def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.device = device
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != self.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
......
...@@ -1331,7 +1331,13 @@ class DiffusionWrapper(torch.nn.Module): ...@@ -1331,7 +1331,13 @@ class DiffusionWrapper(torch.nn.Module):
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
else: else:
cc = c_crossattn cc = c_crossattn
out = self.diffusion_model(x, t, context=cc) if hasattr(self, "scripted_diffusion_model"):
# TorchScript changes names of the arguments
# with argument cc defined as context=cc scripted model will produce
# an error: RuntimeError: forward() is missing value for argument 'argument_3'.
out = self.scripted_diffusion_model(x, t, cc)
else:
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid': elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1) xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1) cc = torch.cat(c_crossattn, 1)
......
...@@ -11,16 +11,17 @@ MODEL_TYPES = { ...@@ -11,16 +11,17 @@ MODEL_TYPES = {
class DPMSolverSampler(object): class DPMSolverSampler(object):
def __init__(self, model, **kwargs): def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != self.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
@torch.no_grad() @torch.no_grad()
......
...@@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding ...@@ -10,16 +10,17 @@ from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object): class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs): def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__() super().__init__()
self.model = model self.model = model
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.device = device
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"): if attr.device != self.device:
attr = attr.to(torch.device("cuda")) attr = attr.to(self.device)
setattr(self, name, attr) setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
......
...@@ -454,6 +454,7 @@ class UNetModel(nn.Module): ...@@ -454,6 +454,7 @@ class UNetModel(nn.Module):
num_classes=None, num_classes=None,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
use_bf16=False,
num_heads=-1, num_heads=-1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
...@@ -518,6 +519,7 @@ class UNetModel(nn.Module): ...@@ -518,6 +519,7 @@ class UNetModel(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32 self.dtype = th.float16 if use_fp16 else th.float32
self.dtype = th.bfloat16 if use_bf16 else self.dtype
self.num_heads = num_heads self.num_heads = num_heads
self.num_head_channels = num_head_channels self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample self.num_heads_upsample = num_heads_upsample
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment