"tests/vscode:/vscode.git/clone" did not exist on "625e82ce624a62e5f4bab22897c7565bbcfb9547"
Unverified Commit b6e0b016 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Lazy Import for Diffusers (#4829)



* initial commit

* move modules to import struct

* add dummy objects and _LazyModule

* add lazy import to schedulers

* clean up unused imports

* lazy import on models module

* lazy import for schedulers module

* add lazy import to pipelines module

* lazy import altdiffusion

* lazy import audio diffusion

* lazy import audioldm

* lazy import consistency model

* lazy import controlnet

* lazy import dance diffusion ddim ddpm

* lazy import deepfloyd

* lazy import kandinsky

* lazy imports

* lazy import semantic diffusion

* lazy imports

* lazy import stable diffusion

* move sd output to its own module

* clean up

* lazy import t2iadapter

* lazy import unclip

* lazy import versatile and vq diffusion

* lazy import vq diffusion

* helper to fetch objects from modules

* lazy import sdxl

* lazy import txt2vid

* lazy import stochastic karras

* fix model imports

* fix bug

* lazy import

* clean up

* clean up

* fixes for tests

* fixes for tests

* clean up

* remove import of torch_utils from utils module

* clean up

* clean up

* fix mistake import statement

* dedicated modules for exporting and loading

* remove testing utils from utils module

* fixes from  merge conflicts

* Update src/diffusers/pipelines/kandinsky2_2/__init__.py

* fix docs

* fix alt diffusion copied from

* fix check dummies

* fix more docs

* remove accelerate import from utils module

* add type checking

* make style

* fix check dummies

* remove torch import from xformers check

* clean up error message

* fixes after upstream merges

* dummy objects fix

* fix tests

* remove unused module import

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 88735249
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from ..utils import ( from ..utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available, is_flax_available,
is_scipy_available, is_scipy_available,
is_torch_available, is_torch_available,
...@@ -22,38 +23,49 @@ from ..utils import ( ...@@ -22,38 +23,49 @@ from ..utils import (
) )
_import_structure = {}
_dummy_modules = {}
try: try:
if not is_torch_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403 from ..utils import dummy_pt_objects # noqa F403
modules = {}
for name in dir(dummy_pt_objects):
if (not name.endswith("Scheduler")) or name.startswith("_"):
continue
modules[name] = getattr(dummy_pt_objects, name)
_dummy_modules.update(modules)
else: else:
from .scheduling_consistency_models import CMStochasticIterativeScheduler _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
from .scheduling_ddim import DDIMScheduler _import_structure["scheduling_ddim"] = ["DDIMScheduler"]
from .scheduling_ddim_inverse import DDIMInverseScheduler _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
from .scheduling_ddim_parallel import DDIMParallelScheduler _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
from .scheduling_ddpm import DDPMScheduler _import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
from .scheduling_ddpm_parallel import DDPMParallelScheduler _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"]
from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"]
from .scheduling_deis_multistep import DEISMultistepScheduler _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"]
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"]
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"]
from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"]
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler _import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
from .scheduling_euler_discrete import EulerDiscreteScheduler _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
from .scheduling_heun_discrete import HeunDiscreteScheduler _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
from .scheduling_ipndm import IPNDMScheduler _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
from .scheduling_karras_ve import KarrasVeScheduler _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
from .scheduling_pndm import PNDMScheduler _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
from .scheduling_repaint import RePaintScheduler _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
from .scheduling_sde_ve import ScoreSdeVeScheduler _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
from .scheduling_sde_vp import ScoreSdeVpScheduler _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
from .scheduling_unclip import UnCLIPScheduler _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
from .scheduling_unipc_multistep import UniPCMultistepScheduler _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
from .scheduling_vq_diffusion import VQDiffusionScheduler _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"]
try: try:
if not is_flax_available(): if not is_flax_available():
...@@ -61,33 +73,59 @@ try: ...@@ -61,33 +73,59 @@ try:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_objects import * # noqa F403 from ..utils.dummy_flax_objects import * # noqa F403
else: else:
from .scheduling_ddim_flax import FlaxDDIMScheduler _import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
from .scheduling_ddpm_flax import FlaxDDPMScheduler _import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler _import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler _import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler _import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
from .scheduling_pndm_flax import FlaxPNDMScheduler _import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler _import_structure["scheduling_sde_ve_flax"] = ["FlaxScoreSdeVeScheduler"]
from .scheduling_utils_flax import ( _import_structure["scheduling_utils_flax"] = [
FlaxKarrasDiffusionSchedulers, "FlaxKarrasDiffusionSchedulers",
FlaxSchedulerMixin, "FlaxSchedulerMixin",
FlaxSchedulerOutput, "FlaxSchedulerOutput",
broadcast_to_shape_from_left, "broadcast_to_shape_from_left",
) ]
try: try:
if not (is_torch_available() and is_scipy_available()): if not (is_torch_available() and is_scipy_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 from ..utils import dummy_torch_and_scipy_objects # noqa F403
modules = {}
for name in dir(dummy_torch_and_scipy_objects):
if (not name.endswith("Scheduler")) or name.startswith("_"):
continue
modules[name] = getattr(dummy_torch_and_scipy_objects, name)
_dummy_modules.update(modules)
else: else:
from .scheduling_lms_discrete import LMSDiscreteScheduler _import_structure["scheduling_lms_discrete"] = ["LMSDiscreteScheduler"]
try: try:
if not (is_torch_available() and is_torchsde_available()): if not (is_torch_available() and is_torchsde_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 from ..utils import dummy_torch_and_torchsde_objects # noqa F403
modules = {}
for name in dir(dummy_torch_and_torchsde_objects):
if (not name.endswith("Scheduler")) or name.startswith("_"):
continue
modules[name] = getattr(dummy_torch_and_torchsde_objects, name)
_dummy_modules.update(modules)
else: else:
from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler _import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"]
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
for name, value in _dummy_modules.items():
setattr(sys.modules[__name__], name, value)
...@@ -19,7 +19,8 @@ import numpy as np ...@@ -19,7 +19,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
......
...@@ -23,7 +23,8 @@ import numpy as np ...@@ -23,7 +23,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
...@@ -23,7 +23,8 @@ import numpy as np ...@@ -23,7 +23,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
...@@ -22,7 +22,8 @@ import numpy as np ...@@ -22,7 +22,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
...@@ -22,7 +22,8 @@ import numpy as np ...@@ -22,7 +22,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
...@@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union ...@@ -22,7 +22,8 @@ from typing import List, Optional, Tuple, Union
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
......
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
......
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
......
...@@ -20,7 +20,7 @@ import numpy as np ...@@ -20,7 +20,7 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor from ..utils.torch_utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
......
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
......
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
......
...@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union ...@@ -21,7 +21,8 @@ from typing import Optional, Tuple, Union
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput from .scheduling_utils import SchedulerMixin, SchedulerOutput
......
...@@ -20,7 +20,7 @@ from typing import Union ...@@ -20,7 +20,7 @@ from typing import Union
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import randn_tensor from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
......
...@@ -20,7 +20,8 @@ import numpy as np ...@@ -20,7 +20,8 @@ import numpy as np
import torch import torch
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, randn_tensor from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin from .scheduling_utils import SchedulerMixin
......
...@@ -18,7 +18,6 @@ import os ...@@ -18,7 +18,6 @@ import os
from packaging import version from packaging import version
from .. import __version__ from .. import __version__
from .accelerate_utils import apply_forward_hook
from .constants import ( from .constants import (
CONFIG_NAME, CONFIG_NAME,
DEPRECATED_REVISION_ARGS, DEPRECATED_REVISION_ARGS,
...@@ -35,6 +34,7 @@ from .constants import ( ...@@ -35,6 +34,7 @@ from .constants import (
from .deprecation_utils import deprecate from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
from .hub_utils import ( from .hub_utils import (
HF_HUB_OFFLINE, HF_HUB_OFFLINE,
PushToHubMixin, PushToHubMixin,
...@@ -52,6 +52,8 @@ from .import_utils import ( ...@@ -52,6 +52,8 @@ from .import_utils import (
USE_TORCH, USE_TORCH,
DummyObject, DummyObject,
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_accelerate_available, is_accelerate_available,
is_accelerate_version, is_accelerate_version,
is_bs4_available, is_bs4_available,
...@@ -78,32 +80,10 @@ from .import_utils import ( ...@@ -78,32 +80,10 @@ from .import_utils import (
is_xformers_available, is_xformers_available,
requires_backends, requires_backends,
) )
from .loading_utils import load_image
from .logging import get_logger from .logging import get_logger
from .outputs import BaseOutput from .outputs import BaseOutput
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .torch_utils import is_compiled_module, randn_tensor
if is_torch_available():
from .testing_utils import (
floats_tensor,
load_hf_numpy,
load_image,
load_numpy,
load_pt,
nightly,
parse_flag_from_env,
print_tensor_test,
require_torch_2,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
torch_device,
)
from .torch_utils import maybe_allow_in_graph
from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
logger = get_logger(__name__) logger = get_logger(__name__)
......
import io
import random
import struct
import tempfile
from contextlib import contextmanager
from typing import List
import numpy as np
import PIL.Image
import PIL.ImageOps
from .import_utils import (
BACKENDS_MAPPING,
is_opencv_available,
)
from .logging import get_logger
global_rng = random.Random()
logger = get_logger(__name__)
@contextmanager
def buffered_writer(raw_f):
f = io.BufferedWriter(raw_f)
yield f
f.flush()
def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
if output_gif_path is None:
output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
image[0].save(
output_gif_path,
save_all=True,
append_images=image[1:],
optimize=False,
duration=100,
loop=0,
)
return output_gif_path
def export_to_ply(mesh, output_ply_path: str = None):
"""
Write a PLY file for a mesh.
"""
if output_ply_path is None:
output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
coords = mesh.verts.detach().cpu().numpy()
faces = mesh.faces.cpu().numpy()
rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
with buffered_writer(open(output_ply_path, "wb")) as f:
f.write(b"ply\n")
f.write(b"format binary_little_endian 1.0\n")
f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
f.write(b"property float x\n")
f.write(b"property float y\n")
f.write(b"property float z\n")
if rgb is not None:
f.write(b"property uchar red\n")
f.write(b"property uchar green\n")
f.write(b"property uchar blue\n")
if faces is not None:
f.write(bytes(f"element face {len(faces)}\n", "ascii"))
f.write(b"property list uchar int vertex_index\n")
f.write(b"end_header\n")
if rgb is not None:
rgb = (rgb * 255.499).round().astype(int)
vertices = [
(*coord, *rgb)
for coord, rgb in zip(
coords.tolist(),
rgb.tolist(),
)
]
format = struct.Struct("<3f3B")
for item in vertices:
f.write(format.pack(*item))
else:
format = struct.Struct("<3f")
for vertex in coords.tolist():
f.write(format.pack(*vertex))
if faces is not None:
format = struct.Struct("<B3I")
for tri in faces.tolist():
f.write(format.pack(len(tri), *tri))
return output_ply_path
def export_to_obj(mesh, output_obj_path: str = None):
if output_obj_path is None:
output_obj_path = tempfile.NamedTemporaryFile(suffix=".obj").name
verts = mesh.verts.detach().cpu().numpy()
faces = mesh.faces.cpu().numpy()
vertex_colors = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
vertices = [
"{} {} {} {} {} {}".format(*coord, *color) for coord, color in zip(verts.tolist(), vertex_colors.tolist())
]
faces = ["f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) for tri in faces.tolist()]
combined_data = ["v " + vertex for vertex in vertices] + faces
with open(output_obj_path, "w") as f:
f.writelines("\n".join(combined_data))
def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
if is_opencv_available():
import cv2
else:
raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
h, w, c = video_frames[0].shape
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
for i in range(len(video_frames)):
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
video_writer.write(img)
return output_video_path
...@@ -19,7 +19,9 @@ import operator as op ...@@ -19,7 +19,9 @@ import operator as op
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
from typing import Union from itertools import chain
from types import ModuleType
from typing import Any, Union
from huggingface_hub.utils import is_jinja_available # noqa: F401 from huggingface_hub.utils import is_jinja_available # noqa: F401
from packaging import version from packaging import version
...@@ -219,10 +221,10 @@ _xformers_available = importlib.util.find_spec("xformers") is not None ...@@ -219,10 +221,10 @@ _xformers_available = importlib.util.find_spec("xformers") is not None
try: try:
_xformers_version = importlib_metadata.version("xformers") _xformers_version = importlib_metadata.version("xformers")
if _torch_available: if _torch_available:
import torch _torch_version = importlib_metadata.version("torch")
if version.Version(_torch_version) < version.Version("1.12"):
raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12")
if version.Version(torch.__version__) < version.Version("1.12"):
raise ValueError("PyTorch should be >= 1.12")
logger.debug(f"Successfully imported xformers version {_xformers_version}") logger.debug(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_xformers_available = False _xformers_available = False
...@@ -647,5 +649,85 @@ def is_k_diffusion_version(operation: str, version: str): ...@@ -647,5 +649,85 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version) return compare_versions(parse(_k_diffusion_version), operation, version)
def get_objects_from_module(module):
"""
Args:
Returns a dict of object names and values in a module, while skipping private/internal objects
module (ModuleType):
Module to extract the objects from.
Returns:
dict: Dictionary of object names and corresponding values
"""
objects = {}
for name in dir(module):
if name.startswith("_"):
continue
objects[name] = getattr(module, name)
return objects
class OptionalDependencyNotAvailable(BaseException): class OptionalDependencyNotAvailable(BaseException):
"""An error indicating that an optional dependency of Diffusers was not found in the environment.""" """An error indicating that an optional dependency of Diffusers was not found in the environment."""
class _LazyModule(ModuleType):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
super().__init__(name)
self._modules = set(import_structure.keys())
self._class_to_module = {}
for key, values in import_structure.items():
for value in values:
self._class_to_module[value] = key
# Needed for autocompletion in an IDE
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
self.__file__ = module_file
self.__spec__ = module_spec
self.__path__ = [os.path.dirname(module_file)]
self._objects = {} if extra_objects is None else extra_objects
self._name = name
self._import_structure = import_structure
# Needed for autocompletion in an IDE
def __dir__(self):
result = super().__dir__()
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
for attr in self.__all__:
if attr not in result:
result.append(attr)
return result
def __getattr__(self, name: str) -> Any:
if name in self._objects:
return self._objects[name]
if name in self._modules:
value = self._get_module(name)
elif name in self._class_to_module.keys():
module = self._get_module(self._class_to_module[name])
value = getattr(module, name)
else:
raise AttributeError(f"module {self.__name__} has no attribute {name}")
setattr(self, name, value)
return value
def _get_module(self, module_name: str):
try:
return importlib.import_module("." + module_name, self.__name__)
except Exception as e:
raise RuntimeError(
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e
def __reduce__(self):
return (self.__class__, (self._name, self.__file__, self._import_structure))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment