Commit 528b1293 authored by anton-l
Browse files

make style

parents f23bb3e8 cbb19ee8
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import pathlib import pathlib
from modeling_ddim import DDIM
import PIL.Image
import numpy as np import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"] model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
for model_id in model_ids: for model_id in model_ids:
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch import torch
import tqdm
from diffusers import DiffusionPipeline
class DDIM(DiffusionPipeline): class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
...@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline): ...@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps) inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device) self.unet.to(torch_device)
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1 # get actual t and t-1
train_step = inference_step_times[t] train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else - 1 prev_train_step = inference_step_times[t - 1] if t > 0 else -1
# compute alphas # compute alphas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step) alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
...@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline): ...@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt() beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
# compute relevant coefficients # compute relevant coefficients
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta coeff_1 = (
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt() (alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
# model forward # model forward
with torch.no_grad(): with torch.no_grad():
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# !pip install diffusers # !pip install diffusers
from modeling_ddim import DDIM
import PIL.Image
import numpy as np import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_id = "fusing/ddpm-cifar10" model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom" model_id = "fusing/ddpm-lsun-bedroom"
......
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import pathlib import pathlib
from modeling_ddpm import DDPM
import PIL.Image
import numpy as np import numpy as np
model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"] import PIL.Image
from modeling_ddpm import DDPM
model_ids = [
"ddpm-lsun-cat",
"ddpm-lsun-cat-ema",
"ddpm-lsun-church-ema",
"ddpm-lsun-church",
"ddpm-lsun-bedroom",
"ddpm-lsun-bedroom-ema",
"ddpm-cifar10-ema",
"ddpm-cifar10",
"ddpm-celeba-hq",
"ddpm-celeba-hq-ema",
]
for model_id in model_ids: for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id) path = os.path.join("/home/patrick/images/hf", model_id)
......
...@@ -14,13 +14,13 @@ ...@@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch import torch
import tqdm
from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline): class DDPM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler): def __init__(self, unet, noise_scheduler):
super().__init__() super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler) self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
...@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline): ...@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
self.unet.to(torch_device) self.unet.to(torch_device)
# 1. Sample gaussian noise # 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator) image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t)) image_coeff = (
clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t)) (1 - self.noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.noise_scheduler.get_alpha(t))
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
* self.noise_scheduler.get_beta(t)
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
...@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline): ...@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator) prev_variance = self.noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance) # v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance sampled_prev_image = prev_image + prev_variance
......
import torch import torch
from torch import nn from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, GLIDETextToImageUNetModel, GLIDESuperResUNetModel from diffusers import (
ClassifierFreeGuidanceScheduler,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE, CLIPTextModel from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer from transformers import CLIPTextConfig, GPT2Tokenizer
...@@ -22,7 +27,9 @@ config = CLIPTextConfig( ...@@ -22,7 +27,9 @@ config = CLIPTextConfig(
use_padding_embeddings=True, use_padding_embeddings=True,
) )
model = CLIPTextModel(config).eval() model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>") tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model hf_encoder = model.text_model
...@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False) ...@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear") upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer, glide = GLIDE(
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler) text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base") glide.save_pretrained("./glide-base")
...@@ -18,10 +18,20 @@ import math ...@@ -18,10 +18,20 @@ import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig import tqdm
from diffusers import (
ClassifierFreeGuidanceScheduler,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
...@@ -34,14 +44,6 @@ from transformers.utils import ( ...@@ -34,14 +44,6 @@ from transformers.utils import (
) )
import numpy as np
import torch
import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer
##################### #####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module) # START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
##################### #####################
...@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline): ...@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer, tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel, upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler upscale_noise_scheduler: GlideDDIMScheduler,
): ):
super().__init__() super().__init__()
self.register_modules( self.register_modules(
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer, text_unet=text_unet,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler text_noise_scheduler=text_noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler,
) )
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t): def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
...@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline): ...@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline):
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
) )
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape) posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor( posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
assert ( assert (
posterior_mean.shape[0] posterior_mean.shape[0]
== posterior_variance.shape[0] == posterior_variance.shape[0]
...@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline): ...@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts. # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997 upsample_temp = 0.997
image = self.upscale_noise_scheduler.sample_noise( image = (
(batch_size, 3, 256, 256), device=torch_device, generator=generator self.upscale_noise_scheduler.sample_noise(
) * upsample_temp (batch_size, 3, 256, 256), device=torch_device, generator=generator
)
* upsample_temp
)
num_timesteps = len(self.upscale_noise_scheduler) num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)): for t in tqdm.tqdm(
reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt( image_coeff = (
self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta( * torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* self.upscale_noise_scheduler.get_beta(t)
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device) time_input = torch.tensor([t] * image.shape[0], device=torch_device)
...@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline): ...@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance # iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, prev_variance = self.upscale_noise_scheduler.sample_variance(
generator=generator) t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance) # v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance sampled_prev_image = prev_image + prev_variance
......
import torch import torch
from diffusers import DiffusionPipeline
import PIL.Image import PIL.Image
from diffusers import DiffusionPipeline
generator = torch.Generator() generator = torch.Generator()
generator = generator.manual_seed(0) generator = generator.manual_seed(0)
...@@ -15,8 +17,8 @@ img = pipeline("a crayon drawing of a corgi", generator) ...@@ -15,8 +17,8 @@ img = pipeline("a crayon drawing of a corgi", generator)
# process image to PIL # process image to PIL
img = img.squeeze(0) img = img.squeeze(0)
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img) image_pil = PIL.Image.fromarray(img)
# save image # save image
image_pil.save("test.png") image_pil.save("test.png")
\ No newline at end of file
...@@ -84,6 +84,7 @@ _deps = [ ...@@ -84,6 +84,7 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"numpy", "numpy",
"pytest", "pytest",
"regex!=2019.12.17",
"requests", "requests",
"torch>=1.4", "torch>=1.4",
"torchvision", "torchvision",
...@@ -168,6 +169,7 @@ install_requires = [ ...@@ -168,6 +169,7 @@ install_requires = [
deps["filelock"], deps["filelock"],
deps["huggingface-hub"], deps["huggingface-hub"],
deps["numpy"], deps["numpy"],
deps["regex"],
deps["requests"], deps["requests"],
deps["torch"], deps["torch"],
deps["torchvision"], deps["torchvision"],
......
...@@ -6,7 +6,7 @@ __version__ = "0.0.1" ...@@ -6,7 +6,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
from .models.unet import UNetModel from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline from .pipeline_utils import DiffusionPipeline
......
...@@ -23,13 +23,13 @@ import os ...@@ -23,13 +23,13 @@ import os
import re import re
from typing import Any, Dict, Tuple, Union from typing import Any, Dict, Tuple, Union
from requests import HTTPError
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from requests import HTTPError
from . import __version__
from .utils import ( from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError, EntryNotFoundError,
RepositoryNotFoundError, RepositoryNotFoundError,
RevisionNotFoundError, RevisionNotFoundError,
...@@ -37,9 +37,6 @@ from .utils import ( ...@@ -37,9 +37,6 @@ from .utils import (
) )
from . import __version__
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json") _re_configuration_file = re.compile(r"config\.(.*)\.json")
...@@ -95,9 +92,7 @@ class ConfigMixin: ...@@ -95,9 +92,7 @@ class ConfigMixin:
@classmethod @classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict( config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
...@@ -157,16 +152,16 @@ class ConfigMixin: ...@@ -157,16 +152,16 @@ class ConfigMixin:
except RepositoryNotFoundError: except RepositoryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on " f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having " " on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass " " having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
"`use_auth_token=True`." " pass `use_auth_token=True`."
) )
except RevisionNotFoundError: except RevisionNotFoundError:
raise EnvironmentError( raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for " " this model name. Check the model page at"
"available revisions." f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
) )
except EntryNotFoundError: except EntryNotFoundError:
raise EnvironmentError( raise EnvironmentError(
...@@ -174,14 +169,16 @@ class ConfigMixin: ...@@ -174,14 +169,16 @@ class ConfigMixin:
) )
except HTTPError as err: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" "There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
) )
except ValueError: except ValueError:
raise EnvironmentError( raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in" f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory" f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the" f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." " run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
) )
except EnvironmentError: except EnvironmentError:
raise EnvironmentError( raise EnvironmentError(
...@@ -195,9 +192,7 @@ class ConfigMixin: ...@@ -195,9 +192,7 @@ class ConfigMixin:
# Load config dict # Load config dict
config_dict = cls._dict_from_json_file(config_file) config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError): except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError( raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
f"It looks like the config file at '{config_file}' is not a valid JSON file."
)
return config_dict return config_dict
......
...@@ -3,29 +3,15 @@ ...@@ -3,29 +3,15 @@
# 2. run `make deps_table_update`` # 2. run `make deps_table_update``
deps = { deps = {
"Pillow": "Pillow", "Pillow": "Pillow",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3", "black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0", "filelock": "filelock",
"dataclasses": "dataclasses", "flake8": "flake8>=3.8.3",
"datasets": "datasets", "huggingface-hub": "huggingface-hub",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"numpy": "numpy>=1.17", "numpy": "numpy",
"pytest": "pytest", "pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.7.0",
"regex": "regex!=2019.12.17", "regex": "regex!=2019.12.17",
"requests": "requests", "requests": "requests",
"sagemaker": "sagemaker>=2.31.0",
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch": "torch>=1.4", "torch": "torch>=1.4",
"torchaudio": "torchaudio", "torchvision": "torchvision",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
} }
...@@ -23,7 +23,8 @@ from pathlib import Path ...@@ -23,7 +23,8 @@ from pathlib import Path
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from huggingface_hub import cached_download from huggingface_hub import cached_download
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
...@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
from torch import Tensor, device from torch import Tensor, device
from requests import HTTPError
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from requests import HTTPError
from .utils import ( from .utils import (
CONFIG_NAME, CONFIG_NAME,
...@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module): ...@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
) )
except EntryNotFoundError: except EntryNotFoundError:
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.") raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
)
except HTTPError as err: except HTTPError as err:
raise EnvironmentError( raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" "There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
) )
except ValueError: except ValueError:
raise EnvironmentError( raise EnvironmentError(
......
...@@ -17,6 +17,6 @@ ...@@ -17,6 +17,6 @@
# limitations under the License. # limitations under the License.
from .unet import UNetModel from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
from .vqvae import VQModel from .vqvae import VQModel
\ No newline at end of file
...@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast ...@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam from torch.optim import Adam
from torch.utils import data from torch.utils import data
from torchvision import transforms, utils
from PIL import Image from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm from tqdm import tqdm
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
...@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin): ...@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
# dataset classes # dataset classes
class Dataset(data.Dataset): class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']): def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__() super().__init__()
self.folder = folder self.folder = folder
self.image_size = image_size self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
self.transform = transforms.Compose([ self.transform = transforms.Compose(
transforms.Resize(image_size), [
transforms.RandomHorizontalFlip(), transforms.Resize(image_size),
transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),
transforms.ToTensor() transforms.CenterCrop(image_size),
]) transforms.ToTensor(),
]
)
def __len__(self): def __len__(self):
return len(self.paths) return len(self.paths)
...@@ -359,7 +362,7 @@ class Dataset(data.Dataset): ...@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
# trainer class # trainer class
class EMA(): class EMA:
def __init__(self, beta): def __init__(self, beta):
super().__init__() super().__init__()
self.beta = beta self.beta = beta
......
...@@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): ...@@ -647,24 +647,24 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
""" """
def __init__( def __init__(
self, self,
in_channels=3, in_channels=3,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
num_res_blocks=3, num_res_blocks=3,
attention_resolutions=(2, 4, 8), attention_resolutions=(2, 4, 8),
dropout=0, dropout=0,
channel_mult=(1, 2, 4, 8), channel_mult=(1, 2, 4, 8),
conv_resample=True, conv_resample=True,
dims=2, dims=2,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
num_heads=1, num_heads=1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
transformer_dim=512 transformer_dim=512,
): ):
super().__init__( super().__init__(
in_channels=in_channels, in_channels=in_channels,
...@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): ...@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample, num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim transformer_dim=transformer_dim,
) )
self.register( self.register(
in_channels=in_channels, in_channels=in_channels,
...@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): ...@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample, num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm, use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown, resblock_updown=resblock_updown,
transformer_dim=transformer_dim transformer_dim=transformer_dim,
) )
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4) self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
...@@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): ...@@ -737,23 +737,23 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
""" """
def __init__( def __init__(
self, self,
in_channels=3, in_channels=3,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
num_res_blocks=3, num_res_blocks=3,
attention_resolutions=(2, 4, 8), attention_resolutions=(2, 4, 8),
dropout=0, dropout=0,
channel_mult=(1, 2, 4, 8), channel_mult=(1, 2, 4, 8),
conv_resample=True, conv_resample=True,
dims=2, dims=2,
use_checkpoint=False, use_checkpoint=False,
use_fp16=False, use_fp16=False,
num_heads=1, num_heads=1,
num_head_channels=-1, num_head_channels=-1,
num_heads_upsample=-1, num_heads_upsample=-1,
use_scale_shift_norm=False, use_scale_shift_norm=False,
resblock_updown=False, resblock_updown=False,
): ):
super().__init__( super().__init__(
in_channels=in_channels, in_channels=in_channels,
...@@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel): ...@@ -809,4 +809,4 @@ class GLIDESuperResUNetModel(GLIDEUNetModel):
h = torch.cat([h, hs.pop()], dim=1) h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb) h = module(h, emb)
return self.out(h) return self.out(h)
\ No newline at end of file
This diff is collapsed.
...@@ -20,10 +20,9 @@ from typing import Optional, Union ...@@ -20,10 +20,9 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from .utils import logging, DIFFUSERS_CACHE
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt" INDEX_FILE = "diffusion_model.pt"
...@@ -105,7 +104,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -105,7 +104,7 @@ class DiffusionPipeline(ConfigMixin):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
Add docstrings Add docstrings
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
......
...@@ -11,12 +11,13 @@ ...@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import math import math
import torch
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json" SAMPLING_CONFIG_NAME = "scheduler_config.json"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment