Commit 8b42c7ce authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make all tests pass

parent 66d5a180
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from tqdm.auto import tqdm import tqdm
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_outputs import BaseModelOutput
......
...@@ -37,7 +37,7 @@ class PNDMPipeline(DiffusionPipeline): ...@@ -37,7 +37,7 @@ class PNDMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop # Sample gaussian noise to begin loop
image = torch.randn( image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator, generator=generator,
) )
image = image.to(torch_device) image = image.to(torch_device)
......
...@@ -1011,7 +1011,7 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1011,7 +1011,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_pndm_cifar10(self): def test_pndm_cifar10(self):
model_id = "google/ddpm-cifar10" model_id = "google/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True) unet = UNetUnconditionalModel.from_pretrained(model_id)
scheduler = PNDMScheduler(tensor_format="pt") scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
...@@ -1072,7 +1072,6 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1072,7 +1072,6 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_score_sde_ve_pipeline(self): def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True) model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
torch.manual_seed(0) torch.manual_seed(0)
if torch.cuda.is_available(): if torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment