Unverified Commit f6f1ec3a authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

allow loading ddpm models into ddim (#1932)

parent beb932c5
...@@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union ...@@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ...schedulers import DDIMScheduler
from ...utils import deprecate, randn_tensor from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -34,6 +35,10 @@ class DDIMPipeline(DiffusionPipeline): ...@@ -34,6 +35,10 @@ class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, scheduler): def __init__(self, unet, scheduler):
super().__init__() super().__init__()
# make sure scheduler can always be converted to DDIM
scheduler = DDIMScheduler.from_config(scheduler.config)
self.register_modules(unet=unet, scheduler=scheduler) self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad() @torch.no_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment