Commit 528b1293 authored by anton-l

make style

parents f23bb3e8 cbb19ee8
#!/usr/bin/env python3
import os
import pathlib

import numpy as np
import PIL.Image
from modeling_ddim import DDIM

model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]

for model_id in model_ids:
......
@@ -14,13 +14,13 @@
# limitations under the License.

import torch
import tqdm

from diffusers import DiffusionPipeline


class DDIM(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
        self.unet.to(torch_device)
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )
        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # get actual t and t-1
            train_step = inference_step_times[t]
            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
            # compute alphas
            alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
            beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
            # compute relevant coefficients
            coeff_1 = (
                (alpha_prod_t_prev - alpha_prod_t).sqrt()
                * alpha_prod_t_prev_rsqrt
                * beta_prod_t_prev_sqrt
                / beta_prod_t_sqrt
                * eta
            )
            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
            # model forward
            with torch.no_grad():
......
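In the hunk above, `coeff_1` is the DDIM noise scale (sigma, scaled by `eta`) and `coeff_2` weights the predicted noise, following Song et al. (2020). A minimal sketch of the standard DDIM step these coefficients typically feed, not the elided repository code; `noise_residual` and `noise` are illustrative names for the model output and a fresh Gaussian sample:

# sketch only: standard DDIM update built from the coefficients above
pred_x0 = (image - beta_prod_t_sqrt * noise_residual) / alpha_prod_t.sqrt()
image = alpha_prod_t_prev.sqrt() * pred_x0 + coeff_2 * noise_residual
if eta > 0:
    image = image + coeff_1 * noise  # noise ~ N(0, I); this stochastic term vanishes for eta = 0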
#!/usr/bin/env python3
# !pip install diffusers
import numpy as np
import PIL.Image
from modeling_ddim import DDIM

model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"
......
#!/usr/bin/env python3
import os
import pathlib

import numpy as np
import PIL.Image
from modeling_ddpm import DDPM

model_ids = [
    "ddpm-lsun-cat",
    "ddpm-lsun-cat-ema",
    "ddpm-lsun-church-ema",
    "ddpm-lsun-church",
    "ddpm-lsun-bedroom",
    "ddpm-lsun-bedroom-ema",
    "ddpm-cifar10-ema",
    "ddpm-cifar10",
    "ddpm-celeba-hq",
    "ddpm-celeba-hq-ema",
]

for model_id in model_ids:
    path = os.path.join("/home/patrick/images/hf", model_id)
......
@@ -14,13 +14,13 @@
# limitations under the License.

import torch
import tqdm

from diffusers import DiffusionPipeline


class DDPM(DiffusionPipeline):
    def __init__(self, unet, noise_scheduler):
        super().__init__()
        self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
        self.unet.to(torch_device)
        # 1. Sample gaussian noise
        image = self.noise_scheduler.sample_noise(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            device=torch_device,
            generator=generator,
        )
        for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.noise_scheduler.get_alpha(t))
                / (1 - self.noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
                * self.noise_scheduler.get_beta(t)
                / (1 - self.noise_scheduler.get_alpha_prod(t))
            )
            # ii) predict noise residual
            with torch.no_grad():
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
            prev_image = clipped_coeff * pred_mean + image_coeff * image
            # iv) sample variance
            prev_variance = self.noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )
            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
......
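For orientation, `clipped_image_coeff` and `clipped_noise_coeff` in the DDPM hunk above are the usual weights for recovering an x_0 estimate from the noise prediction, and `clipped_coeff` and `image_coeff` are the two weights of the DDPM posterior mean (Ho et al. 2020, Eq. 7). A hedged sketch of how they typically combine, not the elided repository code; `noise_residual` is an illustrative name for the model output:

# sketch only: x0 estimate and posterior mean from the coefficients above
pred_original = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
pred_mean = pred_original.clamp(-1, 1)
prev_image = clipped_coeff * pred_mean + image_coeff * image  # as in the hunk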
import torch
from torch import nn

from diffusers import (
    ClassifierFreeGuidanceScheduler,
    GlideDDIMScheduler,
    GLIDESuperResUNetModel,
    GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -22,7 +27,9 @@ config = CLIPTextConfig(
    use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer(
    "./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)

hf_encoder = model.text_model
@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")

glide = GLIDE(
    text_unet=text2im_model,
    text_noise_scheduler=text_scheduler,
    text_encoder=model,
    tokenizer=tokenizer,
    upscale_unet=superres_model,
    upscale_noise_scheduler=upscale_scheduler,
)

glide.save_pretrained("./glide-base")
@@ -18,10 +18,20 @@ import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

import tqdm
from diffusers import (
    ClassifierFreeGuidanceScheduler,
    DiffusionPipeline,
    GlideDDIMScheduler,
    GLIDESuperResUNetModel,
    GLIDETextToImageUNetModel,
)
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
@@ -34,14 +44,6 @@ from transformers.utils import (
)


#####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
#####################
@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline):
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet,
            text_noise_scheduler=text_noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            upscale_unet=upscale_unet,
            upscale_noise_scheduler=upscale_noise_scheduler,
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline):
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline):
        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
        upsample_temp = 0.997
        image = (
            self.upscale_noise_scheduler.sample_noise(
                (batch_size, 3, 256, 256), device=torch_device, generator=generator
            )
            * upsample_temp
        )
        num_timesteps = len(self.upscale_noise_scheduler)
        for t in tqdm.tqdm(
            reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
        ):
            # i) define coefficients for time step t
            clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
            clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
            image_coeff = (
                (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            clipped_coeff = (
                torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
                * self.upscale_noise_scheduler.get_beta(t)
                / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
            )
            # ii) predict noise residual
            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline):
            prev_image = clipped_coeff * pred_mean + image_coeff * image
            # iv) sample variance
            prev_variance = self.upscale_noise_scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )
            # v) sample x_{t-1} ~ N(prev_image, prev_variance)
            sampled_prev_image = prev_image + prev_variance
......
import torch

import PIL.Image
from diffusers import DiffusionPipeline

generator = torch.Generator()
generator = generator.manual_seed(0)
@@ -15,7 +17,7 @@ img = pipeline("a crayon drawing of a corgi", generator)
# process image to PIL
img = img.squeeze(0)
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)
# save image
......
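The post-processing line above maps model outputs from [-1, 1] to 8-bit pixel values; a quick check of the arithmetic:

# worked example of the [-1, 1] -> [0, 255] mapping used above
x = torch.tensor([-1.0, 0.0, 1.0])
((x + 1) * 127.5).round().clamp(0, 255).to(torch.uint8)  # tensor([  0, 128, 255], dtype=torch.uint8)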
@@ -84,6 +84,7 @@ _deps = [
    "isort>=5.5.4",
    "numpy",
    "pytest",
    "regex!=2019.12.17",
    "requests",
    "torch>=1.4",
    "torchvision",
@@ -168,6 +169,7 @@ install_requires = [
    deps["filelock"],
    deps["huggingface-hub"],
    deps["numpy"],
    deps["regex"],
    deps["requests"],
    deps["torch"],
    deps["torchvision"],
......
@@ -6,7 +6,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
......
@@ -23,13 +23,13 @@ import os
import re
from typing import Any, Dict, Tuple, Union

from huggingface_hub import hf_hub_download
from requests import HTTPError

from . import __version__
from .utils import (
    DIFFUSERS_CACHE,
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
@@ -37,9 +37,6 @@ from .utils import (
)


logger = logging.get_logger(__name__)

_re_configuration_file = re.compile(r"config\.(.*)\.json")
@@ -95,9 +92,7 @@ class ConfigMixin:
    @classmethod
    def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
        config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

        init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@@ -157,16 +152,16 @@ class ConfigMixin:
        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
                " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
                " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
                " pass `use_auth_token=True`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                " this model name. Check the model page at"
                f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
@@ -174,14 +169,16 @@ class ConfigMixin:
            )
        except HTTPError as err:
            raise EnvironmentError(
                "There was a specific connection error when trying to load"
                f" {pretrained_model_name_or_path}:\n{err}"
            )
        except ValueError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                " run the library in offline mode at"
                " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
@@ -195,9 +192,7 @@ class ConfigMixin:
            # Load config dict
            config_dict = cls._dict_from_json_file(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        return config_dict
......
@@ -3,29 +3,15 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.9.0",
     "black": "black~=22.0,>=22.3",
-    "codecarbon": "codecarbon==1.2.0",
-    "dataclasses": "dataclasses",
-    "datasets": "datasets",
-    "GitPython": "GitPython<3.1.19",
-    "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
-    "importlib_metadata": "importlib_metadata",
+    "filelock": "filelock",
+    "flake8": "flake8>=3.8.3",
+    "huggingface-hub": "huggingface-hub",
     "isort": "isort>=5.5.4",
-    "numpy": "numpy>=1.17",
+    "numpy": "numpy",
     "pytest": "pytest",
-    "pytest-timeout": "pytest-timeout",
-    "pytest-xdist": "pytest-xdist",
-    "python": "python>=3.7.0",
     "regex": "regex!=2019.12.17",
     "requests": "requests",
-    "sagemaker": "sagemaker>=2.31.0",
-    "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
     "torch": "torch>=1.4",
-    "torchaudio": "torchaudio",
-    "tqdm": "tqdm>=4.27",
-    "unidic": "unidic>=1.0.2",
-    "unidic_lite": "unidic_lite>=1.0.7",
-    "uvicorn": "uvicorn",
+    "torchvision": "torchvision",
 }
@@ -23,7 +23,8 @@ from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub import cached_download

from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
......
@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device

from huggingface_hub import hf_hub_download
from requests import HTTPError

from .utils import (
    CONFIG_NAME,
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
......
@@ -17,6 +17,6 @@
# limitations under the License.

from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel
from .vqvae import VQModel
@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data

from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm

from ..configuration_utils import ConfigMixin
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
# dataset classes


class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]

        self.transform = transforms.Compose(
            [
                transforms.Resize(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.paths)
@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
# trainer class


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
......
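The `EMA` class above only stores the decay `beta`; the exponential-moving-average update such a helper typically applies per parameter is sketched below (the method name and structure are illustrative and not confirmed by the hunk):

def update_average(self, old, new):
    # keep `beta` of the running average and blend in `1 - beta` of the new value
    if old is None:
        return new
    return old * self.beta + (1 - self.beta) * new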
@@ -664,7 +664,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_dim=512,
    ):
        super().__init__(
            in_channels=in_channels,
@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )
        self.register(
            in_channels=in_channels,
@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
            num_heads_upsample=num_heads_upsample,
            use_scale_shift_norm=use_scale_shift_norm,
            resblock_updown=resblock_updown,
            transformer_dim=transformer_dim,
        )
        self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
......
import math
from abc import abstractmethod
from inspect import isfunction

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from einops import rearrange, repeat
except:
    print("Einops is not installed")
    pass
@@ -16,12 +17,13 @@ except:
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
@@ -53,20 +55,13 @@ class GEGLU(nn.Module):
class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)
@@ -90,17 +85,17 @@ class LinearAttention(nn.Module):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
        return self.to_out(out)
@@ -110,26 +105,10 @@ class SpatialSelfAttention(nn.Module):
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
@@ -139,41 +118,38 @@ class SpatialSelfAttention(nn.Module):
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)
        return x + h_


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads
@@ -183,31 +159,34 @@ class CrossAttention(nn.Module):
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
        sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
@@ -228,29 +207,23 @@ class SpatialTransformer(nn.Module):
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for d in range(depth)
            ]
        )

        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
@@ -258,13 +231,14 @@ class SpatialTransformer(nn.Module):
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
        return x + x_in


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
@@ -386,7 +360,7 @@ class AttentionPool2d(nn.Module):
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
@@ -453,9 +427,7 @@ class Upsample(nn.Module):
    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
@@ -472,7 +444,7 @@ class Downsample(nn.Module):
    downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
@@ -480,9 +452,7 @@ class Downsample(nn.Module):
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -558,17 +528,13 @@ class ResBlock(TimestepBlock):
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
@@ -686,7 +652,7 @@ def count_flops_attn(model, _x, y):
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += torch.DoubleTensor([matmul_ops])
@@ -710,9 +676,7 @@ class QKVAttentionLegacy(nn.Module):
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = torch.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
@@ -810,19 +774,23 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        )

        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert num_heads != -1, "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
@@ -852,11 +820,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
@@ -883,7 +847,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
@@ -892,7 +856,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
@@ -914,9 +880,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
@@ -930,7 +894,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            # num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
@@ -947,9 +911,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim),
            ResBlock(
                ch,
                time_embed_dim,
@@ -984,7 +948,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        # num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
@@ -993,7 +957,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        )
                        if not use_spatial_transformer
                        else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
@@ -1026,7 +992,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def convert_to_fp16(self):
@@ -1045,7 +1011,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
@@ -1108,7 +1074,7 @@ class EncoderUNetModel(nn.Module):
        use_new_attention_order=False,
        pool="adaptive",
        *args,
        **kwargs,
    ):
        super().__init__()
@@ -1137,11 +1103,7 @@ class EncoderUNetModel(nn.Module):
        )

        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
@@ -1189,9 +1151,7 @@ class EncoderUNetModel(nn.Module):
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
@@ -1239,9 +1199,7 @@ class EncoderUNetModel(nn.Module):
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels),
            )
        elif pool == "spatial":
            self.out = nn.Sequential(
@@ -1296,4 +1254,3 @@ class EncoderUNetModel(nn.Module):
        else:
            h = h.type(x.dtype)
        return self.out(h)
@@ -20,10 +20,9 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download

from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging


INDEX_FILE = "diffusion_model.pt"
......
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
from torch import nn

from ..configuration_utils import ConfigMixin
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule


SAMPLING_CONFIG_NAME = "scheduler_config.json"
......
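For reference, the `linear_beta_schedule` imported above is, in most DDPM implementations, just an evenly spaced ramp of betas; a hedged sketch under that assumption, not necessarily this repository's exact helper:

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    # one beta per diffusion step, linearly spaced between beta_start and beta_end
    return torch.linspace(beta_start, beta_end, timesteps)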