Commit d0032c60 authored by Patrick von Platen
Browse files

refactor naming

parent 320506c7
...@@ -48,7 +48,7 @@ The class provides functionality to compute previous image according to alpha, b ...@@ -48,7 +48,7 @@ The class provides functionality to compute previous image according to alpha, b
**Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ... **Diffusion Pipeline**: End-to-end pipeline that includes multiple diffusion models, possible text encoders, ...
*Examples*: GLIDE, Latent-Diffusion, Imagen, DALL-E 2 *Examples*: Glide, Latent-Diffusion, Imagen, DALL-E 2
<p align="center"> <p align="center">
<img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/> <img src="https://user-images.githubusercontent.com/10695622/174348898-481bd7c2-5457-4830-89bc-f0907756f64c.jpeg" width="550"/>
...@@ -190,7 +190,7 @@ image_pil.save("test.png") ...@@ -190,7 +190,7 @@ image_pil.save("test.png")
[Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing) [Diffuser](https://diffusion-planning.github.io/) for planning in reinforcement learning: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1TmBmlYeKUZSkUZoJqfBmaicVTKx6nN1R?usp=sharing)
### 2. `diffusers` as a collection of popular Diffusion systems (GLIDE, Dalle, ...) ### 2. `diffusers` as a collection of popular Diffusion systems (Glide, Dalle, ...)
For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines). For more examples see [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
......
import torch import torch
from torch import nn from torch import nn
from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GLIDESuperResUNetModel, GLIDETextToImageUNetModel from diffusers import ClassifierFreeGuidanceScheduler, DDIMScheduler, GlideSuperResUNetModel, GlideTextToImageUNetModel
from diffusers.pipelines.pipeline_glide import GLIDE, CLIPTextModel from diffusers.pipelines.pipeline_glide import Glide, CLIPTextModel
from transformers import CLIPTextConfig, GPT2Tokenizer from transformers import CLIPTextConfig, GPT2Tokenizer
...@@ -55,7 +55,7 @@ for layer_idx in range(config.num_hidden_layers): ...@@ -55,7 +55,7 @@ for layer_idx in range(config.num_hidden_layers):
### Convert the Text-to-Image UNet ### Convert the Text-to-Image UNet
text2im_model = GLIDETextToImageUNetModel( text2im_model = GlideTextToImageUNetModel(
in_channels=3, in_channels=3,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
...@@ -80,7 +80,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule=" ...@@ -80,7 +80,7 @@ text_scheduler = ClassifierFreeGuidanceScheduler(timesteps=1000, beta_schedule="
# wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt # wget https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt
ups_state_dict = torch.load("upsample.pt", map_location="cpu") ups_state_dict = torch.load("upsample.pt", map_location="cpu")
superres_model = GLIDESuperResUNetModel( superres_model = GlideSuperResUNetModel(
in_channels=6, in_channels=6,
model_channels=192, model_channels=192,
out_channels=6, out_channels=6,
...@@ -101,7 +101,7 @@ upscale_scheduler = DDIMScheduler( ...@@ -101,7 +101,7 @@ upscale_scheduler = DDIMScheduler(
timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt" timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
) )
glide = GLIDE( glide = Glide(
text_unet=text2im_model, text_unet=text2im_model,
text_noise_scheduler=text_scheduler, text_noise_scheduler=text_scheduler,
text_encoder=model, text_encoder=model,
......
# flake8: noqa # flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this # There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all. # module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_transformers_available from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
__version__ = "0.0.4" __version__ = "0.0.4"
...@@ -16,8 +16,14 @@ from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMSche ...@@ -16,8 +16,14 @@ from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMSche
if is_transformers_available(): if is_transformers_available():
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .models.unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_grad_tts import UNetGradTTSModel
from .pipelines import GLIDE, GradTTS, LatentDiffusion from .pipelines import Glide, LatentDiffusion
else: else:
from .utils.dummy_transformers_objects import * from .utils.dummy_transformers_objects import *
if is_transformers_available() and is_inflect_available() and is_unidecode_available():
from .pipelines import GradTTS
else:
from .utils.dummy_transformers_and_inflect_and_unidecode_objects import *
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# limitations under the License. # limitations under the License.
from .unet import UNetModel from .unet import UNetModel
from .unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel, GLIDEUNetModel from .unet_glide import GlideSuperResUNetModel, GlideTextToImageUNetModel, GlideUNetModel
from .unet_grad_tts import UNetGradTTSModel from .unet_grad_tts import UNetGradTTSModel
from .unet_ldm import UNetLDMModel from .unet_ldm import UNetLDMModel
from .unet_rl import TemporalUNet from .unet_rl import TemporalUNet
...@@ -388,7 +388,7 @@ class QKVAttention(nn.Module): ...@@ -388,7 +388,7 @@ class QKVAttention(nn.Module):
return a.reshape(bs, -1, length) return a.reshape(bs, -1, length)
class GLIDEUNetModel(ModelMixin, ConfigMixin): class GlideUNetModel(ModelMixin, ConfigMixin):
""" """
The full UNet model with attention and timestep embedding. The full UNet model with attention and timestep embedding.
...@@ -641,7 +641,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin): ...@@ -641,7 +641,7 @@ class GLIDEUNetModel(ModelMixin, ConfigMixin):
return self.out(h) return self.out(h)
class GLIDETextToImageUNetModel(GLIDEUNetModel): class GlideTextToImageUNetModel(GlideUNetModel):
""" """
A UNetModel that performs super-resolution. A UNetModel that performs super-resolution.
...@@ -734,7 +734,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel): ...@@ -734,7 +734,7 @@ class GLIDETextToImageUNetModel(GLIDEUNetModel):
return self.out(h) return self.out(h)
class GLIDESuperResUNetModel(GLIDEUNetModel): class GlideSuperResUNetModel(GlideUNetModel):
""" """
A UNetModel that performs super-resolution. A UNetModel that performs super-resolution.
......
from ..utils import is_transformers_available from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDM from .pipeline_bddm import BDDM
from .pipeline_ddim import DDIM from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM from .pipeline_ddpm import DDPM
...@@ -6,6 +6,9 @@ from .pipeline_pndm import PNDM ...@@ -6,6 +6,9 @@ from .pipeline_pndm import PNDM
if is_transformers_available(): if is_transformers_available():
from .pipeline_glide import GLIDE from .pipeline_glide import Glide
from .pipeline_grad_tts import GradTTS
from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_latent_diffusion import LatentDiffusion
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .pipeline_grad_tts import GradTTS
...@@ -6,20 +6,9 @@ from shutil import copyfile ...@@ -6,20 +6,9 @@ from shutil import copyfile
import torch import torch
import inflect
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from unidecode import unidecode
try:
from unidecode import unidecode
except:
print("unidecode is not installed")
pass
try:
import inflect
except:
print("inflect is not installed")
pass
valid_symbols = [ valid_symbols = [
...@@ -234,12 +223,7 @@ def english_cleaners(text): ...@@ -234,12 +223,7 @@ def english_cleaners(text):
return text return text
try: _inflect = inflect.engine()
_inflect = inflect.engine()
except:
print("inflect is not installed")
_inflect = None
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
......
...@@ -30,7 +30,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo ...@@ -30,7 +30,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.modeling_utils import PreTrainedModel from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from ..models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline from ..pipeline_utils import DiffusionPipeline
from ..schedulers import DDIMScheduler, DDPMScheduler from ..schedulers import DDIMScheduler, DDPMScheduler
from ..utils import logging from ..utils import logging
...@@ -711,14 +711,14 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): ...@@ -711,14 +711,14 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res + torch.zeros(broadcast_shape, device=timesteps.device) return res + torch.zeros(broadcast_shape, device=timesteps.device)
class GLIDE(DiffusionPipeline): class Glide(DiffusionPipeline):
def __init__( def __init__(
self, self,
text_unet: GLIDETextToImageUNetModel, text_unet: GlideTextToImageUNetModel,
text_noise_scheduler: DDPMScheduler, text_noise_scheduler: DDPMScheduler,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer, tokenizer: GPT2Tokenizer,
upscale_unet: GLIDESuperResUNetModel, upscale_unet: GlideSuperResUNetModel,
upscale_noise_scheduler: DDIMScheduler, upscale_noise_scheduler: DDIMScheduler,
): ):
super().__init__() super().__init__()
......
...@@ -73,7 +73,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -73,7 +73,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps) self.betas = betas_for_alpha_bar(timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -132,7 +132,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ...@@ -132,7 +132,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
std_dev_t = eta * variance ** (0.5) std_dev_t = eta * variance ** (0.5)
if use_clipped_residual: if use_clipped_residual:
# the residual is always re-derived from the clipped x_0 in GLIDE # the residual is always re-derived from the clipped x_0 in Glide
residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
......
...@@ -76,7 +76,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -76,7 +76,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif beta_schedule == "linear": elif beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps) self.betas = betas_for_alpha_bar(timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
...@@ -108,7 +108,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -108,7 +108,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif variance_type == "fixed_large": elif variance_type == "fixed_large":
variance = self.betas[t] variance = self.betas[t]
elif variance_type == "fixed_large_log": elif variance_type == "fixed_large_log":
# GLIDE max_log # Glide max_log
variance = self.log(self.betas[t]) variance = self.log(self.betas[t])
return variance return variance
......
...@@ -66,7 +66,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -66,7 +66,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if beta_schedule == "linear": if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
elif beta_schedule == "squaredcos_cap_v2": elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule # Glide cosine schedule
self.betas = betas_for_alpha_bar(timesteps) self.betas = betas_for_alpha_bar(timesteps)
else: else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
......
...@@ -45,10 +45,34 @@ except importlib_metadata.PackageNotFoundError: ...@@ -45,10 +45,34 @@ except importlib_metadata.PackageNotFoundError:
_transformers_available = False _transformers_available = False
# Optional backend `inflect`: it counts as available only when the import
# system can locate the module AND distribution metadata is installed for it.
_inflect_available = importlib.util.find_spec("inflect") is not None
try:
    _inflect_version = importlib_metadata.version("inflect")
except importlib_metadata.PackageNotFoundError:
    # No distribution metadata: report the backend as missing even if a module
    # with that name was found above.
    _inflect_available = False
else:
    logger.debug(f"Successfully imported inflect version {_inflect_version}")


# Optional backend `unidecode`: same two-step probe as for `inflect`.
_unidecode_available = importlib.util.find_spec("unidecode") is not None
try:
    _unidecode_version = importlib_metadata.version("unidecode")
except importlib_metadata.PackageNotFoundError:
    # No distribution metadata: report the backend as missing.
    _unidecode_available = False
else:
    logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
def is_transformers_available(): def is_transformers_available():
return _transformers_available return _transformers_available
def is_inflect_available():
    """Return the module-level `_inflect_available` flag computed at import time."""
    return _inflect_available
def is_unidecode_available():
    """Return the module-level `_unidecode_available` flag computed at import time."""
    return _unidecode_available
class RepositoryNotFoundError(HTTPError): class RepositoryNotFoundError(HTTPError):
""" """
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
...@@ -70,9 +94,23 @@ TRANSFORMERS_IMPORT_ERROR = """ ...@@ -70,9 +94,23 @@ TRANSFORMERS_IMPORT_ERROR = """
""" """
# Message template for a missing `unidecode` backend. `{0}` is a positional
# format placeholder; presumably it receives the name of the object that needs
# the backend -- confirm against the code that formats this template.
UNIDECODE_IMPORT_ERROR = """
{0} requires the unidecode library but it was not found in your environment. You can install it with pip:
`pip install Unidecode`
"""
# Message template for a missing `inflect` backend; same `{0}` placeholder
# convention as the template above.
INFLECT_IMPORT_ERROR = """
{0} requires the inflect library but it was not found in your environment. You can install it with pip:
`pip install inflect`
"""
BACKENDS_MAPPING = OrderedDict( BACKENDS_MAPPING = OrderedDict(
[ [
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)), ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
] ]
) )
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends
class GradTTS(metaclass=DummyObject):
_backends = ["transformers", "inflect", "unidecode"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers", "inflect", "unidecode"])
...@@ -3,21 +3,21 @@ ...@@ -3,21 +3,21 @@
from ..utils import DummyObject, requires_backends from ..utils import DummyObject, requires_backends
class GLIDESuperResUNetModel(metaclass=DummyObject): class GlideSuperResUNetModel(metaclass=DummyObject):
_backends = ["transformers"] _backends = ["transformers"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"]) requires_backends(self, ["transformers"])
class GLIDETextToImageUNetModel(metaclass=DummyObject): class GlideTextToImageUNetModel(metaclass=DummyObject):
_backends = ["transformers"] _backends = ["transformers"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["transformers"]) requires_backends(self, ["transformers"])
class GLIDEUNetModel(metaclass=DummyObject): class GlideUNetModel(metaclass=DummyObject):
_backends = ["transformers"] _backends = ["transformers"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -31,10 +31,7 @@ class UNetGradTTSModel(metaclass=DummyObject): ...@@ -31,10 +31,7 @@ class UNetGradTTSModel(metaclass=DummyObject):
requires_backends(self, ["transformers"]) requires_backends(self, ["transformers"])
GLIDE = None class Glide(metaclass=DummyObject):
class GradTTS(metaclass=DummyObject):
_backends = ["transformers"] _backends = ["transformers"]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
......
...@@ -21,17 +21,17 @@ import unittest ...@@ -21,17 +21,17 @@ import unittest
import numpy as np import numpy as np
import torch import torch
import pytest
from diffusers import ( from diffusers import (
BDDM, BDDM,
DDIM, DDIM,
DDPM, DDPM,
GLIDE, Glide,
PNDM, PNDM,
DDIMScheduler, DDIMScheduler,
DDPMScheduler, DDPMScheduler,
GLIDESuperResUNetModel, GlideSuperResUNetModel,
GLIDETextToImageUNetModel, GlideTextToImageUNetModel,
GradTTS,
LatentDiffusion, LatentDiffusion,
PNDMScheduler, PNDMScheduler,
UNetGradTTSModel, UNetGradTTSModel,
...@@ -247,13 +247,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -247,13 +247,13 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053]) expected_output_slice = torch.tensor([0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
model_class = GLIDESuperResUNetModel model_class = GlideSuperResUNetModel
@property @property
def dummy_input(self): def dummy_input(self):
...@@ -309,7 +309,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -309,7 +309,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = GLIDESuperResUNetModel.from_pretrained( model, loading_info = GlideSuperResUNetModel.from_pretrained(
"fusing/glide-super-res-dummy", output_loading_info=True "fusing/glide-super-res-dummy", output_loading_info=True
) )
self.assertIsNotNone(model) self.assertIsNotNone(model)
...@@ -321,7 +321,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -321,7 +321,7 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained(self): def test_output_pretrained(self):
model = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy") model = GlideSuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -342,8 +342,8 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): ...@@ -342,8 +342,8 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
model_class = GLIDETextToImageUNetModel model_class = GlideTextToImageUNetModel
@property @property
def dummy_input(self): def dummy_input(self):
...@@ -401,7 +401,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -401,7 +401,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = GLIDETextToImageUNetModel.from_pretrained( model, loading_info = GlideTextToImageUNetModel.from_pretrained(
"fusing/unet-glide-text2im-dummy", output_loading_info=True "fusing/unet-glide-text2im-dummy", output_loading_info=True
) )
self.assertIsNotNone(model) self.assertIsNotNone(model)
...@@ -413,7 +413,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -413,7 +413,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
def test_output_pretrained(self): def test_output_pretrained(self):
model = GLIDETextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy") model = GlideTextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy")
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -431,7 +431,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -431,7 +431,7 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
output, _ = torch.split(output, 3, dim=1) output, _ = torch.split(output, 3, dim=1)
output_slice = output[0, -1, -3:, -3:].flatten() output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([ 2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845]) expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
...@@ -571,7 +571,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): ...@@ -571,7 +571,7 @@ class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
output_slice = output[0, -3:, -3:].flatten() output_slice = output[0, -3:, -3:].flatten()
# fmt: off # fmt: off
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617]) expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
# fmt: on # fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
...@@ -689,7 +689,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -689,7 +689,7 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_glide_text2img(self): def test_glide_text2img(self):
model_id = "fusing/glide-base" model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id) glide = Glide.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi" prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
...@@ -701,6 +701,20 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -701,6 +701,20 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784]) expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_grad_tts(self):
# Integration test: load the GradTTS pipeline from the hub checkpoint below
# and run it on a short sentence.
model_id = "fusing/grad-tts-libri-tts"
grad_tts = GradTTS.from_pretrained(model_id)
text = "Hello world, I missed you so much."
# generate mel spectrograms using text
mel_spec = grad_tts(text)
# NOTE(review): (1, 256, 256, 3) looks like an image layout rather than a
# mel-spectrogram shape -- confirm the real output shape of the pipeline.
assert mel_spec.shape == (1, 256, 256, 3)
# NOTE(review): these nine values are identical to `expected_slice` in
# test_glide_text2img, so they look like a copied placeholder -- replace with
# values recorded from an actual GradTTS run.
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
# NOTE(review): `mel_spec.flatten()` keeps every element of the output, which
# cannot broadcast against a 9-element tensor for the shape asserted above;
# presumably a 3x3 slice should be taken before flattening -- confirm.
assert (mel_spec.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12) noise_scheduler = DDPMScheduler(timesteps=12)
......
...@@ -23,10 +23,9 @@ import re ...@@ -23,10 +23,9 @@ import re
PATH_TO_DIFFUSERS = "src/diffusers" PATH_TO_DIFFUSERS = "src/diffusers"
# Matches is_xxx_available() # Matches is_xxx_available()
_re_backend = re.compile(r"if is\_([a-z_]*)_available\(\)") _re_backend = re.compile(r"is\_([a-z_]*)_available\(\)")
# Matches from xxx import bla # Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+not\s+is\_[a-z]*\_available\(\)")
DUMMY_CONSTANT = """ DUMMY_CONSTANT = """
...@@ -54,7 +53,7 @@ def find_backend(line): ...@@ -54,7 +53,7 @@ def find_backend(line):
if len(backends) == 0: if len(backends) == 0:
return None return None
return backends[0] return "_and_".join(backends)
def read_init(): def read_init():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment