Unverified Commit 2bc82d63 authored by hlky's avatar hlky Committed by GitHub
Browse files

DiffusionPipeline mixin `to`+FromOriginalModelMixin/FromSingleFileMixin...

DiffusionPipeline mixin `to`+FromOriginalModelMixin/FromSingleFileMixin `from_single_file` type hint (#10811)

* DiffusionPipeline mixin `to` type hint

* FromOriginalModelMixin from_single_file

* FromSingleFileMixin from_single_file
parent 924f880d
......@@ -19,6 +19,7 @@ import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
from packaging import version
from typing_extensions import Self
from ..utils import deprecate, is_transformers_available, logging
from .single_file_utils import (
......@@ -269,7 +270,7 @@ class FromSingleFileMixin:
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
......
......@@ -19,6 +19,7 @@ from typing import Optional
import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
......@@ -148,7 +149,7 @@ class FromOriginalModelMixin:
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
r"""
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
is set in evaluation mode (`model.eval()`) by default.
......
......@@ -324,7 +324,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
create_pr=create_pr,
)
def to(self, *args, **kwargs):
def to(self, *args, **kwargs) -> Self:
r"""
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment