Unverified commit ada09bd3, authored by Sid Sahai and committed by GitHub

[Type Hints] DDIM pipelines (#345)



* type hints

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent cc59b056
@@ -15,7 +15,7 @@
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
@@ -31,11 +31,11 @@ class DDIMPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        batch_size=1,
-        generator=None,
-        eta=0.0,
-        num_inference_steps=50,
-        output_type="pil",
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        eta: float = 0.0,
+        num_inference_steps: int = 50,
+        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
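With the hints in place, the expected type of each DDIM sampling argument is visible at the call site. A minimal usage sketch; the checkpoint id is a placeholder and the `.images` attribute of `ImagePipelineOutput` is assumed from the surrounding API, not part of this diff:

import torch
from diffusers import DDIMPipeline

# Placeholder checkpoint id -- substitute any DDIM-compatible model on the Hub.
pipe = DDIMPipeline.from_pretrained("some-org/some-ddim-checkpoint")

# The annotated signature documents what each argument should be:
# batch_size: int, generator: Optional[torch.Generator], eta: float,
# num_inference_steps: int, output_type: Optional[str], return_dict: bool.
output = pipe(
    batch_size=4,
    generator=torch.Generator().manual_seed(0),
    eta=0.0,
    num_inference_steps=50,
    output_type="pil",
)
images = output.images  # a list of PIL images when output_type="pil"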
@@ -15,7 +15,7 @@
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
@@ -30,7 +30,12 @@ class DDPMPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
-        self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
+        self,
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
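The DDPM signature gets the same treatment, and the `Union[ImagePipelineOutput, Tuple]` return annotation mirrors the `return_dict` switch. A short sketch under the same assumptions (placeholder checkpoint id, `.images` field taken from the surrounding API):

import torch
from diffusers import DDPMPipeline

# Placeholder checkpoint id -- any DDPM checkpoint on the Hub would do.
pipe = DDPMPipeline.from_pretrained("your-org/your-ddpm-checkpoint")

# return_dict=True (the default) returns an ImagePipelineOutput with an .images field;
# return_dict=False returns a plain tuple, matching the Union[...] return annotation.
result = pipe(batch_size=1, generator=torch.Generator().manual_seed(0), output_type="pil")
images = result.images

images_as_tuple = pipe(batch_size=1, return_dict=False)  # -> (images,)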
@@ -10,13 +10,23 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging

+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


 class LDMTextToImagePipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+    def __init__(
+        self,
+        vqvae: Union[VQModel, AutoencoderKL],
+        bert: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        unet: Union[UNet2DModel, UNet2DConditionModel],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+    ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
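The typed constructor spells out which component classes `LDMTextToImagePipeline` accepts (a VQ-VAE or KL autoencoder, a BERT-style text encoder and tokenizer, a 2D UNet, and one of the supported schedulers). A hedged sketch of typical use; the checkpoint id and the `__call__` arguments shown here are illustrative and not part of this diff:

import torch
from diffusers import LDMTextToImagePipeline

# Illustrative checkpoint id; any LDM text-to-image checkpoint exposing the
# vqvae/bert/tokenizer/unet/scheduler components registered above would work.
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")

output = pipe(
    prompt="a painting of a squirrel eating a burger",
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(0),
)
image = output.images[0]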
@@ -618,7 +628,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
 class LDMBertModel(LDMBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: LDMBertConfig):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
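Annotating the config as `LDMBertConfig` mainly benefits static analysis: a type checker can now flag a mismatched config object when the model is constructed. A small sketch, assuming the module path below and the default `LDMBertConfig` hyper-parameters (neither is shown in this diff):

# Assumed module path, for illustration only.
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
    LDMBertConfig,
    LDMBertModel,
)

config = LDMBertConfig()      # rely on whatever defaults the config defines
model = LDMBertModel(config)  # mypy/pyright now check this is an LDMBertConfig, not e.g. a dict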