Unverified Commit 9c3820d0 authored by Patrick von Platen, committed by GitHub

Big Model Renaming (#109)

* up

* change model name

* renaming

* more changes

* up

* up

* up

* save checkpoint

* finish api / naming

* finish config renaming

* rename all weights

* finish really
parent 13e37cab
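
For reference, this commit renames UNetUnconditionalModel to UNet2DModel and renames the corresponding config keys, as seen throughout the diff below. A minimal sketch of the new constructor names follows (the tiny block sizes are the dummy values used in the tests, not a recommended configuration; this assumes a diffusers install that already includes this commit):

from diffusers import UNet2DModel

# Renamed in this commit:
#   UNetUnconditionalModel -> UNet2DModel
#   block_channels         -> block_out_channels
#   num_res_blocks         -> layers_per_block
#   image_size             -> sample_size
#   num_head_channels      -> attention_head_dim
#   down_blocks            -> down_block_types (block names drop the "UNetRes" prefix)
#   up_blocks              -> up_block_types
model = UNet2DModel(
    sample_size=32,               # formerly image_size
    in_channels=3,
    out_channels=3,
    layers_per_block=2,           # formerly num_res_blocks
    block_out_channels=(32, 64),  # formerly block_channels
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)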
@@ -14,7 +14,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
     def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        img_size = self.model.config.image_size
+        img_size = self.model.config.sample_size
         shape = (1, 3, img_size, img_size)
         model = self.model.to(device)
...
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
 import math
 import tempfile
@@ -23,7 +22,7 @@ import numpy as np
 import torch
 import PIL
-from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
+from diffusers import UNet2DConditionModel  # noqa: F401 TODO(Patrick) - need to write tests with it
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -36,7 +35,7 @@ from diffusers import (
     PNDMScheduler,
     ScoreSdeVePipeline,
     ScoreSdeVeScheduler,
-    UNetUnconditionalModel,
+    UNet2DModel,
     VQModel,
 )
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -271,7 +270,7 @@ class ModelTesterMixin:
 class UnetModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetUnconditionalModel
+    model_class = UNet2DModel
 
     @property
     def dummy_input(self):
@@ -294,14 +293,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_channels": (32, 64),
-            "down_blocks": ("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
-            "up_blocks": ("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
-            "num_head_channels": None,
+            "block_out_channels": (32, 64),
+            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
+            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
+            "attention_head_dim": None,
             "out_channels": 3,
             "in_channels": 3,
-            "num_res_blocks": 2,
-            "image_size": 32,
+            "layers_per_block": 2,
+            "sample_size": 32,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -309,14 +308,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     # TODO(Patrick) - Re-add this test after having correctly added the final VE checkpoints
     # def test_output_pretrained(self):
-    #     model = UNetUnconditionalModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
+    #     model = UNet2DModel.from_pretrained("fusing/ddpm_dummy_update", subfolder="unet")
     #     model.eval()
     #
     #     torch.manual_seed(0)
     #     if torch.cuda.is_available():
     #         torch.cuda.manual_seed_all(0)
     #
-    #     noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+    #     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
     #     time_step = torch.tensor([10])
     #
     #     with torch.no_grad():
@@ -330,7 +329,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
 class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetUnconditionalModel
+    model_class = UNet2DModel
 
     @property
     def dummy_input(self):
@@ -353,23 +352,23 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "image_size": 32,
+            "sample_size": 32,
             "in_channels": 4,
             "out_channels": 4,
-            "num_res_blocks": 2,
-            "block_channels": (32, 64),
-            "num_head_channels": 32,
-            "conv_resample": True,
-            "down_blocks": ("UNetResDownBlock2D", "UNetResDownBlock2D"),
-            "up_blocks": ("UNetResUpBlock2D", "UNetResUpBlock2D"),
+            "layers_per_block": 2,
+            "block_out_channels": (32, 64),
+            "attention_head_dim": 32,
+            "down_block_types": ("DownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "UpBlock2D"),
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained(
-            "fusing/unet-ldm-dummy-update", output_loading_info=True
+        model, loading_info = UNet2DModel.from_pretrained(
+            "/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
         )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -379,14 +378,14 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/unet-ldm-dummy-update")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update")
         model.eval()
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
-        noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+        noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         time_step = torch.tensor([10] * noise.shape[0])
 
         with torch.no_grad():
@@ -409,7 +408,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
     #     if torch.cuda.is_available():
    #         torch.cuda.manual_seed_all(0)
     #
-    #     noise = torch.randn(1, model.config.in_channels, model.config.image_size, model.config.image_size)
+    #     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
     #     context = torch.ones((1, 16, 64), dtype=torch.float32)
     #     time_step = torch.tensor([10] * noise.shape[0])
     #
@@ -426,13 +425,12 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = UNetUnconditionalModel
+    model_class = UNet2DModel
 
     @property
-    def dummy_input(self):
+    def dummy_input(self, sizes=(32, 32)):
         batch_size = 4
         num_channels = 3
-        sizes = (32, 32)
 
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor(batch_size * [10]).to(torch_device)
@@ -449,44 +447,47 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_channels": [32, 64, 64, 64],
+            "block_out_channels": [32, 64, 64, 64],
             "in_channels": 3,
-            "num_res_blocks": 1,
+            "layers_per_block": 1,
             "out_channels": 3,
             "time_embedding_type": "fourier",
-            "resnet_eps": 1e-6,
+            "norm_eps": 1e-6,
             "mid_block_scale_factor": math.sqrt(2.0),
-            "resnet_num_groups": None,
-            "down_blocks": [
-                "UNetResSkipDownBlock2D",
-                "UNetResAttnSkipDownBlock2D",
-                "UNetResSkipDownBlock2D",
-                "UNetResSkipDownBlock2D",
+            "norm_num_groups": None,
+            "down_block_types": [
+                "SkipDownBlock2D",
+                "AttnSkipDownBlock2D",
+                "SkipDownBlock2D",
+                "SkipDownBlock2D",
             ],
-            "up_blocks": [
-                "UNetResSkipUpBlock2D",
-                "UNetResSkipUpBlock2D",
-                "UNetResAttnSkipUpBlock2D",
-                "UNetResSkipUpBlock2D",
+            "up_block_types": [
+                "SkipUpBlock2D",
+                "SkipUpBlock2D",
+                "AttnSkipUpBlock2D",
+                "SkipUpBlock2D",
             ],
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
     def test_from_pretrained_hub(self):
-        model, loading_info = UNetUnconditionalModel.from_pretrained(
-            "fusing/ncsnpp-ffhq-ve-dummy-update", output_loading_info=True
+        model, loading_info = UNet2DModel.from_pretrained(
+            "/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
         )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
 
         model.to(torch_device)
-        image = model(**self.dummy_input)
+        inputs = self.dummy_input
+        noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
+        inputs["sample"] = noise
+        image = model(**inputs)
 
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained_ve_mid(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-celebahq-256")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256")
         model.to(torch_device)
 
         torch.manual_seed(0)
@@ -511,7 +512,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
 
     def test_output_pretrained_ve_large(self):
-        model = UNetUnconditionalModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update")
         model.to(torch_device)
 
         torch.manual_seed(0)
@@ -540,10 +541,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = VQModel
 
     @property
-    def dummy_input(self):
+    def dummy_input(self, sizes=(32, 32)):
         batch_size = 4
         num_channels = 3
-        sizes = (32, 32)
 
         image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
@@ -570,7 +570,6 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
             "embed_dim": 3,
             "sane_index_shape": False,
             "ch_mult": (1,),
-            "dropout": 0.0,
             "double_z": False,
         }
         inputs_dict = self.dummy_input
@@ -583,7 +582,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         pass
 
     def test_from_pretrained_hub(self):
-        model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
+        model, loading_info = VQModel.from_pretrained(
+            "/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -593,7 +594,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained(self):
-        model = VQModel.from_pretrained("fusing/vqgan-dummy")
+        model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy")
         model.eval()
 
         torch.manual_seed(0)
@@ -654,7 +655,9 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         pass
 
     def test_from_pretrained_hub(self):
-        model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
+        model, loading_info = AutoencoderKL.from_pretrained(
+            "/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
+        )
         self.assertIsNotNone(model)
         self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -664,7 +667,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         assert image is not None, "Make sure output is not None"
 
     def test_output_pretrained(self):
-        model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
+        model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy")
         model.eval()
 
         torch.manual_seed(0)
@@ -685,14 +688,14 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
-        model = UNetUnconditionalModel(
-            block_channels=(32, 64),
-            num_res_blocks=2,
-            image_size=32,
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
             in_channels=3,
             out_channels=3,
-            down_blocks=("UNetResDownBlock2D", "UNetResAttnDownBlock2D"),
-            up_blocks=("UNetResAttnUpBlock2D", "UNetResUpBlock2D"),
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
         )
         schedular = DDPMScheduler(num_train_timesteps=10)
@@ -712,7 +715,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_from_pretrained_hub(self):
-        model_path = "google/ddpm-cifar10-32"
+        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
 
         ddpm = DDPMPipeline.from_pretrained(model_path)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
@@ -730,7 +733,7 @@
     @slow
     def test_output_format(self):
-        model_path = "google/ddpm-cifar10-32"
+        model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
 
         pipe = DDIMPipeline.from_pretrained(model_path)
@@ -751,9 +754,9 @@
     @slow
     def test_ddpm_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
+        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
 
-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDPMScheduler.from_config(model_id)
         scheduler = scheduler.set_format("pt")
@@ -770,9 +773,9 @@
     @slow
     def test_ddim_lsun(self):
-        model_id = "google/ddpm-ema-bedroom-256"
+        model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256"
 
-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDIMScheduler.from_config(model_id)
 
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -788,9 +791,9 @@
     @slow
     def test_ddim_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
+        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
 
-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = DDIMScheduler(tensor_format="pt")
 
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
@@ -806,9 +809,9 @@
     @slow
     def test_pndm_cifar10(self):
-        model_id = "google/ddpm-cifar10-32"
+        model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
 
-        unet = UNetUnconditionalModel.from_pretrained(model_id)
+        unet = UNet2DModel.from_pretrained(model_id)
         scheduler = PNDMScheduler(tensor_format="pt")
 
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
@@ -823,7 +826,7 @@
     @slow
     def test_ldm_text2img(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -839,7 +842,7 @@
     @slow
     def test_ldm_text2img_fast(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -853,7 +856,7 @@
     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
+        model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256")
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
+        scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256")
 
         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
@@ -874,7 +877,7 @@
     @slow
     def test_ldm_uncond(self):
-        ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+        ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
...