Unverified Commit 988369a0 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

Merge branch 'main' into grad-tts

parents 5a3467e6 bed32182
...@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): ...@@ -44,7 +44,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
clip_predicted_image=clip_predicted_image, clip_predicted_image=clip_predicted_image,
) )
self.timesteps = int(timesteps) self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image self.clip_image = clip_predicted_image
self.variance_type = variance_type self.variance_type = variance_type
......
...@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -84,7 +84,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps)) inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order) warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order
)
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1])) self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
return self.warmup_time_steps[num_inference_steps] return self.warmup_time_steps[num_inference_steps]
...@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ...@@ -137,7 +139,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
at = alphas_cump[t + 1].view(-1, 1, 1, 1) at = alphas_cump[t + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1) at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et) x_delta = (at_next - at) * (
(1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x
- 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et
)
x_next = x + x_delta x_next = x + x_delta
return x_next return x_next
......
...@@ -49,16 +49,16 @@ _tqdm_active = True ...@@ -49,16 +49,16 @@ _tqdm_active = True
def _get_default_logging_level(): def _get_default_logging_level():
""" """
If TRANSFORMERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
not - fall back to `_default_log_level` not - fall back to `_default_log_level`
""" """
env_level_str = os.getenv("TRANSFORMERS_VERBOSITY", None) env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
if env_level_str: if env_level_str:
if env_level_str in log_levels: if env_level_str in log_levels:
return log_levels[env_level_str] return log_levels[env_level_str]
else: else:
logging.getLogger().warning( logging.getLogger().warning(
f"Unknown option TRANSFORMERS_VERBOSITY={env_level_str}, " f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
f"has to be one of: { ', '.join(log_levels.keys()) }" f"has to be one of: { ', '.join(log_levels.keys()) }"
) )
return _default_log_level return _default_log_level
...@@ -126,14 +126,14 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: ...@@ -126,14 +126,14 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
def get_verbosity() -> int: def get_verbosity() -> int:
""" """
Return the current level for the 🤗 Transformers's root logger as an int. Return the current level for the 🤗 Diffusers' root logger as an int.
Returns: Returns:
`int`: The logging level. `int`: The logging level.
<Tip> <Tip>
🤗 Transformers has the following logging levels: 🤗 Diffusers has the following logging levels:
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL` - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
- 40: `diffusers.logging.ERROR` - 40: `diffusers.logging.ERROR`
...@@ -149,7 +149,7 @@ def get_verbosity() -> int: ...@@ -149,7 +149,7 @@ def get_verbosity() -> int:
def set_verbosity(verbosity: int) -> None: def set_verbosity(verbosity: int) -> None:
""" """
Set the verbosity level for the 🤗 Transformers's root logger. Set the verbosity level for the 🤗 Diffusers' root logger.
Args: Args:
verbosity (`int`): verbosity (`int`):
...@@ -187,7 +187,7 @@ def set_verbosity_error(): ...@@ -187,7 +187,7 @@ def set_verbosity_error():
def disable_default_handler() -> None: def disable_default_handler() -> None:
"""Disable the default handler of the HuggingFace Transformers's root logger.""" """Disable the default handler of the HuggingFace Diffusers' root logger."""
_configure_library_root_logger() _configure_library_root_logger()
...@@ -196,7 +196,7 @@ def disable_default_handler() -> None: ...@@ -196,7 +196,7 @@ def disable_default_handler() -> None:
def enable_default_handler() -> None: def enable_default_handler() -> None:
"""Enable the default handler of the HuggingFace Transformers's root logger.""" """Enable the default handler of the HuggingFace Diffusers' root logger."""
_configure_library_root_logger() _configure_library_root_logger()
...@@ -205,7 +205,7 @@ def enable_default_handler() -> None: ...@@ -205,7 +205,7 @@ def enable_default_handler() -> None:
def add_handler(handler: logging.Handler) -> None: def add_handler(handler: logging.Handler) -> None:
"""adds a handler to the HuggingFace Transformers's root logger.""" """adds a handler to the HuggingFace Diffusers' root logger."""
_configure_library_root_logger() _configure_library_root_logger()
...@@ -214,7 +214,7 @@ def add_handler(handler: logging.Handler) -> None: ...@@ -214,7 +214,7 @@ def add_handler(handler: logging.Handler) -> None:
def remove_handler(handler: logging.Handler) -> None: def remove_handler(handler: logging.Handler) -> None:
"""removes given handler from the HuggingFace Transformers's root logger.""" """removes given handler from the HuggingFace Diffusers' root logger."""
_configure_library_root_logger() _configure_library_root_logger()
...@@ -233,7 +233,7 @@ def disable_propagation() -> None: ...@@ -233,7 +233,7 @@ def disable_propagation() -> None:
def enable_propagation() -> None: def enable_propagation() -> None:
""" """
Enable propagation of the library log outputs. Please disable the HuggingFace Transformers's default handler to Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to
prevent double logging if the root logger has been configured. prevent double logging if the root logger has been configured.
""" """
...@@ -243,7 +243,7 @@ def enable_propagation() -> None: ...@@ -243,7 +243,7 @@ def enable_propagation() -> None:
def enable_explicit_format() -> None: def enable_explicit_format() -> None:
""" """
Enable explicit formatting for every HuggingFace Transformers's logger. The explicit formatter is as follows: Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
``` ```
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
``` ```
...@@ -258,7 +258,7 @@ def enable_explicit_format() -> None: ...@@ -258,7 +258,7 @@ def enable_explicit_format() -> None:
def reset_format() -> None: def reset_format() -> None:
""" """
Resets the formatting for HuggingFace Transformers's loggers. Resets the formatting for HuggingFace Diffusers' loggers.
All handlers currently bound to the root logger are affected by this method. All handlers currently bound to the root logger are affected by this method.
""" """
...@@ -270,10 +270,10 @@ def reset_format() -> None: ...@@ -270,10 +270,10 @@ def reset_format() -> None:
def warning_advice(self, *args, **kwargs): def warning_advice(self, *args, **kwargs):
""" """
This method is identical to `logger.warning()`, but if env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
warning will not be printed warning will not be printed
""" """
no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False) no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
if no_advisory_warnings: if no_advisory_warnings:
return return
self.warning(*args, **kwargs) self.warning(*args, **kwargs)
......
...@@ -19,7 +19,18 @@ import unittest ...@@ -19,7 +19,18 @@ import unittest
import torch import torch
from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel from diffusers import (
BDDM,
DDIM,
DDPM,
GLIDE,
PNDM,
DDIMScheduler,
DDPMScheduler,
LatentDiffusion,
PNDMScheduler,
UNetModel,
)
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave from diffusers.pipelines.pipeline_bddm import DiffWave
...@@ -214,6 +225,21 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -214,6 +225,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12) noise_scheduler = DDPMScheduler(timesteps=12)
...@@ -229,17 +255,3 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -229,17 +255,3 @@ class PipelineTesterMixin(unittest.TestCase):
_ = BDDM.from_pretrained(tmpdirname) _ = BDDM.from_pretrained(tmpdirname)
# check if the same works using the DiffusionPipeline class # check if the same works using the DiffusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname) _ = DiffusionPipeline.from_pretrained(tmpdirname)
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment