Commit a5bdb678 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix importing diffusers without transformers installed

parent c4335626
......@@ -178,10 +178,10 @@ else:
from .pipelines import AudioDiffusionPipeline, Mel
try:
if not (is_torch_available() and is_note_seq_available()):
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_note_seq_objects import * # noqa F403
from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .pipelines import SpectrogramDiffusionPipeline
......
......@@ -306,7 +306,7 @@ class TextualInversionLoaderMixin:
Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.
"""
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer):
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
......@@ -334,7 +334,7 @@ class TextualInversionLoaderMixin:
return prompts
def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
......
......@@ -26,7 +26,6 @@ else:
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .spectrogram_diffusion import SpectrogramDiffusionPipeline
from .stochastic_karras_ve import KarrasVePipeline
try:
......@@ -132,9 +131,9 @@ else:
FlaxStableDiffusionPipeline,
)
try:
if not (is_note_seq_available()):
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_note_seq_objects import * # noqa F403
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import MidiProcessor
from .spectrogram_diffusion import SpectrogramDiffusionPipeline
# flake8: noqa
from ...utils import is_note_seq_available
from ...utils import is_note_seq_available, is_transformers_available
from ...utils import OptionalDependencyNotAvailable
from .notes_encoder import SpectrogramNotesEncoder
from .continous_encoder import SpectrogramContEncoder
from .pipeline_spectrogram_diffusion import (
SpectrogramContEncoder,
SpectrogramDiffusionPipeline,
T5FilmDecoder,
)
if is_note_seq_available():
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .notes_encoder import SpectrogramNotesEncoder
from .continous_encoder import SpectrogramContEncoder
from .pipeline_spectrogram_diffusion import (
SpectrogramContEncoder,
SpectrogramDiffusionPipeline,
T5FilmDecoder,
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .midi_utils import MidiProcessor
......@@ -3,15 +3,15 @@ from ..utils import DummyObject, requires_backends
class SpectrogramDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "note_seq"]
_backends = ["transformers", "torch", "note_seq"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "note_seq"])
requires_backends(self, ["transformers", "torch", "note_seq"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "note_seq"])
requires_backends(cls, ["transformers", "torch", "note_seq"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "note_seq"])
requires_backends(cls, ["transformers", "torch", "note_seq"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment