Commit 528b1293 authored by anton-l

make style

parents f23bb3e8 cbb19ee8
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
for model_id in model_ids:
......
@@ -14,13 +14,13 @@
# limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch
import tqdm
from diffusers import DiffusionPipeline
class DDIM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
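# register_modules sets the components as attributes and records them in the
# pipeline config, so save_pretrained / from_pretrained can persist and restore them together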
@@ -34,12 +34,16 @@ class DDIM(DiffusionPipeline):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else - 1
prev_train_step = inference_step_times[t - 1] if t > 0 else -1
# compute alphas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
@@ -50,8 +54,14 @@ class DDIM(DiffusionPipeline):
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
# compute relevant coefficients
coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
coeff_1 = (
(alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
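# For reference: these correspond to Eq. (12) of the DDIM paper (Song et al., 2021), i.e.
#   coeff_1 = sigma_t = eta * sqrt((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) * sqrt(1 - alpha_prod_t / alpha_prod_t_prev)
#   coeff_2 = sqrt(1 - alpha_prod_t_prev - sigma_t**2), the weight on the predicted noise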
# model forward
with torch.no_grad():
......
#!/usr/bin/env python3
# !pip install diffusers
from modeling_ddim import DDIM
import PIL.Image
import numpy as np
import PIL.Image
from modeling_ddim import DDIM
model_id = "fusing/ddpm-cifar10"
model_id = "fusing/ddpm-lsun-bedroom"
......
#!/usr/bin/env python3
import os
import pathlib
from modeling_ddpm import DDPM
import PIL.Image
import numpy as np
model_ids = ["ddpm-lsun-cat", "ddpm-lsun-cat-ema", "ddpm-lsun-church-ema", "ddpm-lsun-church", "ddpm-lsun-bedroom", "ddpm-lsun-bedroom-ema", "ddpm-cifar10-ema", "ddpm-cifar10", "ddpm-celeba-hq", "ddpm-celeba-hq-ema"]
import PIL.Image
from modeling_ddpm import DDPM
model_ids = [
"ddpm-lsun-cat",
"ddpm-lsun-cat-ema",
"ddpm-lsun-church-ema",
"ddpm-lsun-church",
"ddpm-lsun-bedroom",
"ddpm-lsun-bedroom-ema",
"ddpm-cifar10-ema",
"ddpm-cifar10",
"ddpm-celeba-hq",
"ddpm-celeba-hq-ema",
]
for model_id in model_ids:
path = os.path.join("/home/patrick/images/hf", model_id)
......
@@ -14,13 +14,13 @@
# limitations under the License.
from diffusers import DiffusionPipeline
import tqdm
import torch
import tqdm
from diffusers import DiffusionPipeline
class DDPM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -31,13 +31,25 @@ class DDPM(DiffusionPipeline):
self.unet.to(torch_device)
# 1. Sample gaussian noise
image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(self.noise_scheduler.get_alpha(t)) / (1 - self.noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) * self.noise_scheduler.get_beta(t) / (1 - self.noise_scheduler.get_alpha_prod(t))
image_coeff = (
(1 - self.noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.noise_scheduler.get_alpha(t))
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1))
* self.noise_scheduler.get_beta(t)
/ (1 - self.noise_scheduler.get_alpha_prod(t))
)
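# For reference: clipped_image_coeff / clipped_noise_coeff invert the forward process to
# recover the x_0 prediction (x_0 = x_t / sqrt(alpha_prod) - sqrt(1 / alpha_prod - 1) * eps),
# and clipped_coeff / image_coeff are the posterior mean weights of Eq. (7) in the DDPM
# paper (Ho et al., 2020): mean = clipped_coeff * x_0 + image_coeff * x_t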
# ii) predict noise residual
with torch.no_grad():
@@ -50,7 +62,9 @@ class DDPM(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device, generator=generator)
prev_variance = self.noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
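# (this assumes sample_variance returns noise pre-scaled by the posterior standard
#  deviation, i.e. sigma_t * z, so the sum is a draw from N(prev_image, sigma_t**2))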
......
import torch
from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from diffusers import (
ClassifierFreeGuidanceScheduler,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from modeling_glide import GLIDE, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer
@@ -22,7 +27,9 @@ config = CLIPTextConfig(
use_padding_embeddings=True,
)
model = CLIPTextModel(config).eval()
tokenizer = GPT2Tokenizer("./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>")
tokenizer = GPT2Tokenizer(
"./glide-base/tokenizer/vocab.json", "./glide-base/tokenizer/merges.txt", pad_token="<|endoftext|>"
)
hf_encoder = model.text_model
@@ -97,10 +104,13 @@ superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = GlideDDIMScheduler(timesteps=1000, beta_schedule="linear")
glide = GLIDE(text_unet=text2im_model, text_noise_scheduler=text_scheduler, text_encoder=model, tokenizer=tokenizer,
upscale_unet=superres_model, upscale_noise_scheduler=upscale_scheduler)
glide = GLIDE(
text_unet=text2im_model,
text_noise_scheduler=text_scheduler,
text_encoder=model,
tokenizer=tokenizer,
upscale_unet=superres_model,
upscale_noise_scheduler=upscale_scheduler,
)
glide.save_pretrained("./glide-base")
@@ -18,10 +18,20 @@ import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
import tqdm
from diffusers import (
ClassifierFreeGuidanceScheduler,
DiffusionPipeline,
GlideDDIMScheduler,
GLIDESuperResUNetModel,
GLIDETextToImageUNetModel,
)
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
@@ -34,14 +44,6 @@ from transformers.utils import (
)
import numpy as np
import torch
import tqdm
from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from transformers import GPT2Tokenizer
#####################
# START OF THE CLIP MODEL COPY-PASTE (with a modified attention module)
#####################
@@ -725,12 +727,16 @@ class GLIDE(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel,
upscale_noise_scheduler: GlideDDIMScheduler
upscale_noise_scheduler: GlideDDIMScheduler,
):
super().__init__()
self.register_modules(
text_unet=text_unet, text_noise_scheduler=text_noise_scheduler, text_encoder=text_encoder, tokenizer=tokenizer,
upscale_unet=upscale_unet, upscale_noise_scheduler=upscale_noise_scheduler
text_unet=text_unet,
text_noise_scheduler=text_noise_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler,
)
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
@@ -746,9 +752,7 @@ class GLIDE(DiffusionPipeline):
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = _extract_into_tensor(
scheduler.posterior_log_variance_clipped, t, x_t.shape
)
posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
assert (
posterior_mean.shape[0]
== posterior_variance.shape[0]
@@ -869,19 +873,30 @@ class GLIDE(DiffusionPipeline):
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
upsample_temp = 0.997
image = self.upscale_noise_scheduler.sample_noise(
image = (
self.upscale_noise_scheduler.sample_noise(
(batch_size, 3, 256, 256), device=torch_device, generator=generator
) * upsample_temp
)
* upsample_temp
)
num_timesteps = len(self.upscale_noise_scheduler)
for t in tqdm.tqdm(reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)):
for t in tqdm.tqdm(
reversed(range(len(self.upscale_noise_scheduler))), total=len(self.upscale_noise_scheduler)
):
# i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / self.upscale_noise_scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * torch.sqrt(
self.upscale_noise_scheduler.get_alpha(t)) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
clipped_coeff = torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1)) * self.upscale_noise_scheduler.get_beta(
t) / (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
image_coeff = (
(1 - self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* torch.sqrt(self.upscale_noise_scheduler.get_alpha(t))
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(self.upscale_noise_scheduler.get_alpha_prod(t - 1))
* self.upscale_noise_scheduler.get_beta(t)
/ (1 - self.upscale_noise_scheduler.get_alpha_prod(t))
)
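# (same x_0-recovery and posterior-mean coefficients as in the DDPM pipeline above)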
# ii) predict noise residual
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
@@ -895,8 +910,9 @@ class GLIDE(DiffusionPipeline):
prev_image = clipped_coeff * pred_mean + image_coeff * image
# iv) sample variance
prev_variance = self.upscale_noise_scheduler.sample_variance(t, prev_image.shape, device=torch_device,
generator=generator)
prev_variance = self.upscale_noise_scheduler.sample_variance(
t, prev_image.shape, device=torch_device, generator=generator
)
# v) sample x_{t-1} ~ N(prev_image, prev_variance)
sampled_prev_image = prev_image + prev_variance
......
import torch
from diffusers import DiffusionPipeline
import PIL.Image
from diffusers import DiffusionPipeline
generator = torch.Generator()
generator = generator.manual_seed(0)
@@ -15,7 +17,7 @@ img = pipeline("a crayon drawing of a corgi", generator)
# process image to PIL
img = img.squeeze(0)
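# rescale from the model's [-1, 1] output range to [0, 255] uint8 for PIL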
img = ((img + 1)*127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
img = ((img + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(img)
# save image
......
@@ -84,6 +84,7 @@ _deps = [
"isort>=5.5.4",
"numpy",
"pytest",
"regex!=2019.12.17",
"requests",
"torch>=1.4",
"torchvision",
@@ -168,6 +169,7 @@ install_requires = [
deps["filelock"],
deps["huggingface-hub"],
deps["numpy"],
deps["regex"],
deps["requests"],
deps["torch"],
deps["torchvision"],
......
@@ -6,7 +6,7 @@ __version__ = "0.0.1"
from .modeling_utils import ModelMixin
from .models.unet import UNetModel
from .models.unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .models.vqvae import VQModel
from .pipeline_utils import DiffusionPipeline
......
@@ -23,13 +23,13 @@ import os
import re
from typing import Any, Dict, Tuple, Union
from requests import HTTPError
from huggingface_hub import hf_hub_download
from requests import HTTPError
from . import __version__
from .utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
@@ -37,9 +37,6 @@ from .utils import (
)
from . import __version__
logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
@@ -95,9 +92,7 @@ class ConfigMixin:
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict(
pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
)
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
@@ -157,16 +152,16 @@ class ConfigMixin:
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
"`use_auth_token=True`."
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed"
" on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token"
" having permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and"
" pass `use_auth_token=True`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
"available revisions."
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
" this model name. Check the model page at"
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
@@ -174,14 +169,16 @@ class ConfigMixin:
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
f" containing a {cls.config_name} file.\nCheckout your internet connection or see how to run the"
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
@@ -195,9 +192,7 @@ class ConfigMixin:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{config_file}' is not a valid JSON file."
)
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
return config_dict
......
@@ -3,29 +3,15 @@
# 2. run `make deps_table_update``
deps = {
"Pillow": "Pillow",
"accelerate": "accelerate>=0.9.0",
"black": "black~=22.0,>=22.3",
"codecarbon": "codecarbon==1.2.0",
"dataclasses": "dataclasses",
"datasets": "datasets",
"GitPython": "GitPython<3.1.19",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
"importlib_metadata": "importlib_metadata",
"filelock": "filelock",
"flake8": "flake8>=3.8.3",
"huggingface-hub": "huggingface-hub",
"isort": "isort>=5.5.4",
"numpy": "numpy>=1.17",
"numpy": "numpy",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.7.0",
"regex": "regex!=2019.12.17",
"requests": "requests",
"sagemaker": "sagemaker>=2.31.0",
"tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
"torch": "torch>=1.4",
"torchaudio": "torchaudio",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
"uvicorn": "uvicorn",
"torchvision": "torchvision",
}
@@ -23,7 +23,8 @@ from pathlib import Path
from typing import Dict, Optional, Union
from huggingface_hub import cached_download
from .utils import HF_MODULES_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, logging
from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
......
@@ -20,8 +20,8 @@ from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import Tensor, device
from requests import HTTPError
from huggingface_hub import hf_hub_download
from requests import HTTPError
from .utils import (
CONFIG_NAME,
@@ -379,10 +379,13 @@ class ModelMixin(torch.nn.Module):
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}.")
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {model_file}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
......
@@ -17,6 +17,6 @@
# limitations under the License.
from .unet import UNetModel
from .unet_glide import GLIDETextToImageUNetModel, GLIDESuperResUNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .unet_ldm import UNetLDMModel
from .vqvae import VQModel
@@ -25,8 +25,8 @@ from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data
from torchvision import transforms, utils
from PIL import Image
from torchvision import transforms, utils
from tqdm import tqdm
from ..configuration_utils import ConfigMixin
@@ -335,19 +335,22 @@ class UNetModel(ModelMixin, ConfigMixin):
# dataset classes
class Dataset(data.Dataset):
def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
self.transform = transforms.Compose([
self.transform = transforms.Compose(
[
transforms.Resize(image_size),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor()
])
transforms.ToTensor(),
]
)
def __len__(self):
return len(self.paths)
@@ -359,7 +362,7 @@ class Dataset(data.Dataset):
# trainer class
class EMA():
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
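# (the elided update below is assumed to be the standard exponential moving average,
#  per parameter: avg = beta * avg + (1 - beta) * new_value)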
......
@@ -664,7 +664,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
transformer_dim=512
transformer_dim=512,
):
super().__init__(
in_channels=in_channels,
@@ -683,7 +683,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
transformer_dim=transformer_dim,
)
self.register(
in_channels=in_channels,
@@ -702,7 +702,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
transformer_dim=transformer_dim
transformer_dim=transformer_dim,
)
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
......
This diff is collapsed.
@@ -20,10 +20,9 @@ from typing import Optional, Union
from huggingface_hub import snapshot_download
from .utils import logging, DIFFUSERS_CACHE
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .utils import DIFFUSERS_CACHE, logging
INDEX_FILE = "diffusion_model.pt"
......
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import math
import torch
from torch import nn
from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar
from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json"
......