Commit 528b1293 authored by anton-l
Browse files

make style

parents f23bb3e8 cbb19ee8
...@@ -11,12 +11,12 @@ ...@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import torch
import numpy as np import numpy as np
import torch
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from .schedulers_utils import linear_beta_schedule, betas_for_alpha_bar from .schedulers_utils import betas_for_alpha_bar, linear_beta_schedule
SAMPLING_CONFIG_NAME = "scheduler_config.json" SAMPLING_CONFIG_NAME = "scheduler_config.json"
...@@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin): ...@@ -26,12 +26,7 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
config_name = SAMPLING_CONFIG_NAME config_name = SAMPLING_CONFIG_NAME
def __init__( def __init__(self, timesteps=1000, beta_schedule="linear", variance_type="fixed_large"):
self,
timesteps=1000,
beta_schedule="linear",
variance_type="fixed_large"
):
super().__init__() super().__init__()
self.register( self.register(
timesteps=timesteps, timesteps=timesteps,
...@@ -93,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin): ...@@ -93,4 +88,4 @@ class GlideDDIMScheduler(nn.Module, ConfigMixin):
return torch.randn(shape, generator=generator).to(device) return torch.randn(shape, generator=generator).to(device)
def __len__(self): def __len__(self):
return self.num_timesteps return self.num_timesteps
\ No newline at end of file
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
# There's no way to ignore "F401 '...' imported but unused" warnings in this # There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all. # module, but to preserve other warnings. So, don't check this module at all.
import os
# Copyright 2021 The HuggingFace Inc. team. All rights reserved. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -19,7 +21,7 @@ ...@@ -19,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
import os
hf_cache_home = os.path.expanduser( hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
......
...@@ -14,19 +14,19 @@ ...@@ -14,19 +14,19 @@
# limitations under the License. # limitations under the License.
import os
import random import random
import tempfile import tempfile
import unittest import unittest
import os
from distutils.util import strtobool from distutils.util import strtobool
import torch import torch
from diffusers import GaussianDDPMScheduler, UNetModel from diffusers import GaussianDDPMScheduler, UNetModel
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
from models.vision.ddpm.modeling_ddpm import DDPM from diffusers.pipeline_utils import DiffusionPipeline
from models.vision.ddim.modeling_ddim import DDIM from models.vision.ddim.modeling_ddim import DDIM
from models.vision.ddpm.modeling_ddpm import DDPM
global_rng = random.Random() global_rng = random.Random()
...@@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase): ...@@ -85,7 +85,6 @@ class ConfigTester(unittest.TestCase):
ConfigMixin.from_config("dummy_path") ConfigMixin.from_config("dummy_path")
def test_save_load(self): def test_save_load(self):
class SampleObject(ConfigMixin): class SampleObject(ConfigMixin):
config_name = "config.json" config_name = "config.json"
...@@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase): ...@@ -153,7 +152,6 @@ class ModelTesterMixin(unittest.TestCase):
class SamplerTesterMixin(unittest.TestCase): class SamplerTesterMixin(unittest.TestCase):
@slow @slow
def test_sample(self): def test_sample(self):
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
...@@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -163,15 +161,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) (1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
...@@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -201,7 +207,9 @@ class SamplerTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256) assert image.shape == (1, 3, 256, 256)
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
expected_slice = torch.tensor([-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]) expected_slice = torch.tensor(
[-0.1636, -0.1765, -0.1968, -0.1338, -0.1432, -0.1622, -0.1793, -0.2001, -0.2280]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_sample_fast(self): def test_sample_fast(self):
...@@ -212,15 +220,23 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -212,15 +220,23 @@ class SamplerTesterMixin(unittest.TestCase):
model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device)
# 2. Sample gaussian noise # 2. Sample gaussian noise
image = scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) image = scheduler.sample_noise(
(1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator
)
# 3. Denoise # 3. Denoise
for t in reversed(range(len(scheduler))): for t in reversed(range(len(scheduler))):
# i) define coefficients for time step t # i) define coefficients for time step t
clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t)) clipped_image_coeff = 1 / torch.sqrt(scheduler.get_alpha_prod(t))
clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1) clipped_noise_coeff = torch.sqrt(1 / scheduler.get_alpha_prod(t) - 1)
image_coeff = (1 - scheduler.get_alpha_prod(t - 1)) * torch.sqrt(scheduler.get_alpha(t)) / (1 - scheduler.get_alpha_prod(t)) image_coeff = (
clipped_coeff = torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t)) (1 - scheduler.get_alpha_prod(t - 1))
* torch.sqrt(scheduler.get_alpha(t))
/ (1 - scheduler.get_alpha_prod(t))
)
clipped_coeff = (
torch.sqrt(scheduler.get_alpha_prod(t - 1)) * scheduler.get_beta(t) / (1 - scheduler.get_alpha_prod(t))
)
# ii) predict noise residual # ii) predict noise residual
with torch.no_grad(): with torch.no_grad():
...@@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase): ...@@ -246,7 +262,6 @@ class SamplerTesterMixin(unittest.TestCase):
class PipelineTesterMixin(unittest.TestCase): class PipelineTesterMixin(unittest.TestCase):
def test_from_pretrained_save_pretrained(self): def test_from_pretrained_save_pretrained(self):
# 1. Load models # 1. Load models
model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32) model = UNetModel(ch=32, ch_mult=(1, 2), num_res_blocks=2, attn_resolutions=(16,), resolution=32)
...@@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -309,5 +324,7 @@ class PipelineTesterMixin(unittest.TestCase):
image_slice = image[0, -1, -3:, -3:].cpu() image_slice = image[0, -1, -3:, -3:].cpu()
assert image.shape == (1, 3, 32, 32) assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor([-0.7688, -0.7690, -0.7597, -0.7660, -0.7713, -0.7531, -0.7009, -0.7098, -0.7350]) expected_slice = torch.tensor(
[-0.7383, -0.7385, -0.7298, -0.7364, -0.7414, -0.7239, -0.6737, -0.6813, -0.7068]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment