Commit 8b42c7ce authored by Patrick von Platen's avatar Patrick von Platen
Browse files

make all tests pass

parent 66d5a180
......@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import tqdm
import tqdm
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
......
......@@ -37,7 +37,7 @@ class PNDMPipeline(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
)
image = image.to(torch_device)
......
......@@ -1011,7 +1011,7 @@ class PipelineTesterMixin(unittest.TestCase):
def test_pndm_cifar10(self):
model_id = "google/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
unet = UNetUnconditionalModel.from_pretrained(model_id)
scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
......@@ -1072,7 +1072,6 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("fusing/ffhq_ncsnpp", sde=True)
model = UNetUnconditionalModel.from_pretrained("google/ffhq_ncsnpp")
torch.manual_seed(0)
if torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment