Unverified Commit 5311f564 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Final fixes (#118)

final fixes before release
parent 3b7f514a
...@@ -6,19 +6,20 @@ from tqdm.auto import tqdm ...@@ -6,19 +6,20 @@ from tqdm.auto import tqdm
class ScoreSdeVePipeline(DiffusionPipeline): class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
self.register_modules(model=model, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"): def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
if torch_device is None: if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
img_size = self.model.config.sample_size img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size) shape = (batch_size, 3, img_size, img_size)
model = self.model.to(torch_device) model = self.unet.to(torch_device)
sample = torch.randn(*shape) * self.scheduler.config.sigma_max sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = sample.to(torch_device) sample = sample.to(torch_device)
...@@ -31,7 +32,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -31,7 +32,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
# correction step # correction step
for _ in range(self.scheduler.correct_steps): for _ in range(self.scheduler.correct_steps):
model_output = self.model(sample, sigma_t)["sample"] model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
# prediction step # prediction step
...@@ -40,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): ...@@ -40,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
sample = sample.clamp(0, 1) sample = sample_mean.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy() sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil": if output_type == "pil":
sample = self.numpy_to_pil(sample) sample = self.numpy_to_pil(sample)
......
...@@ -848,15 +848,12 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -848,15 +848,12 @@ class PipelineTesterMixin(unittest.TestCase):
@slow @slow
def test_score_sde_ve_pipeline(self): def test_score_sde_ve_pipeline(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-church-256") model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id)
torch.manual_seed(0) scheduler = ScoreSdeVeScheduler.from_config(model_id)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
torch.manual_seed(0) torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"] image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment