"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "a1adceb090bbef2a0da5ffd87e7ea2d2e989474e"
Unverified Commit ada09bd3 authored by Sid Sahai, committed by GitHub

[Type Hints] DDIM pipelines (#345)



* type hints

* Apply suggestions from code review
Co-authored-by: Anton Lozhkov <anton@huggingface.co>
parent cc59b056
src/diffusers/pipelines/ddim/pipeline_ddim.py

@@ -15,7 +15,7 @@
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
@@ -31,11 +31,11 @@ class DDIMPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
         self,
-        batch_size=1,
-        generator=None,
-        eta=0.0,
-        num_inference_steps=50,
-        output_type="pil",
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        eta: float = 0.0,
+        num_inference_steps: int = 50,
+        output_type: Optional[str] = "pil",
         return_dict: bool = True,
         **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
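For context, the annotated call site now reads as below. This is a hedged usage sketch, not part of the commit: the tiny randomly initialized UNet and the 5-step run are assumptions chosen only to exercise the signature cheaply against the diffusers API of this era (set_format-style schedulers). eta interpolates between deterministic DDIM sampling (0.0) and DDPM-like stochastic sampling (1.0).

import torch

from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

# Tiny, randomly initialized components -- just enough to exercise the types.
unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3,
                   block_out_channels=(32, 64, 64, 64), layers_per_block=1)
pipe = DDIMPipeline(unet=unet, scheduler=DDIMScheduler())

# Every argument matches its new annotation: generator is an
# Optional[torch.Generator], eta a float, num_inference_steps an int.
result = pipe(
    batch_size=1,
    generator=torch.manual_seed(0),
    eta=0.0,
    num_inference_steps=5,
    output_type="pil",
    return_dict=True,
)
image = result.images[0]  # ImagePipelineOutput.images holds the decoded PIL images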
src/diffusers/pipelines/ddpm/pipeline_ddpm.py

@@ -15,7 +15,7 @@
 import warnings
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
@@ -30,7 +30,12 @@ class DDPMPipeline(DiffusionPipeline):
     @torch.no_grad()
     def __call__(
-        self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
+        self,
+        batch_size: int = 1,
+        generator: Optional[torch.Generator] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        **kwargs,
     ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
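Same pattern at the DDPM call site; the signature has no eta or num_inference_steps because ancestral DDPM sampling walks the full training schedule, and the retained **kwargs is what catches the deprecated torch_device argument handled just below. A hedged sketch under the same toy-model assumptions as above (the shortened 50-step schedule is only to keep the toy run fast):

import torch

from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

unet = UNet2DModel(sample_size=32, in_channels=3, out_channels=3,
                   block_out_channels=(32, 64, 64, 64), layers_per_block=1)
pipe = DDPMPipeline(unet=unet, scheduler=DDPMScheduler(num_train_timesteps=50))

result = pipe(batch_size=1, generator=torch.manual_seed(0), output_type="pil")
image = result.images[0]

# Passing return_dict=False yields a plain tuple instead, which is why the
# return annotation is Union[ImagePipelineOutput, Tuple]:
(images,) = pipe(batch_size=1, return_dict=False)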
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

@@ -10,13 +10,23 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import logging

+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


 class LDMTextToImagePipeline(DiffusionPipeline):
-    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
+    def __init__(
+        self,
+        vqvae: Union[VQModel, AutoencoderKL],
+        bert: PreTrainedModel,
+        tokenizer: PreTrainedTokenizer,
+        unet: Union[UNet2DModel, UNet2DConditionModel],
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+    ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
         self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
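The Union annotations mirror what published checkpoints actually contain, e.g. a first stage that may be either a VQModel or an AutoencoderKL. A hedged loading sketch (the CompVis/ldm-text2im-large-256 checkpoint and the short 5-step call are illustrative, not part of this commit):

from diffusers import DiffusionPipeline

# from_pretrained dispatches to LDMTextToImagePipeline via the checkpoint's
# model_index.json and fills each typed slot with the matching component.
pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")

# The registered modules now carry the annotated types, so an IDE or mypy can
# check attribute access on pipe.vqvae, pipe.unet, pipe.scheduler, etc.
out = pipe(prompt="a painting of a squirrel eating a banana", num_inference_steps=5)
image = out.images[0]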
@@ -618,7 +628,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
 class LDMBertModel(LDMBertPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: LDMBertConfig):
         super().__init__(config)
         self.model = LDMBertEncoder(config)
         self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
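Finally, a sketch of what the config: LDMBertConfig annotation documents. The small field values below are arbitrary assumptions chosen to keep a toy instance cheap to build; the real text encoder ships with much larger ones:

import torch

from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
    LDMBertConfig,
    LDMBertModel,
)

# Deliberately tiny config; only the shapes matter for this illustration.
config = LDMBertConfig(vocab_size=1000, d_model=64, encoder_layers=2,
                       encoder_ffn_dim=128, encoder_attention_heads=4, head_dim=16)
model = LDMBertModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
last_hidden = model(input_ids).last_hidden_state  # shape (1, 16, 64)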