Unverified Commit 7aa6af11 authored by Dhruv Nair, committed by GitHub

[Refactor] Move testing utils out of src (#12238)

* update

* update

* update

* update

* update

* merge main

* Revert "merge main"

This reverts commit 65efbcead58644b31596ed2d714f7cee0e0238d3.
parent 87b800e1
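
In practice, the refactor below means test-only helpers are imported from the in-repo tests/testing_utils.py, while device and determinism helpers move to diffusers.utils.torch_utils, instead of everything coming from diffusers.utils.testing_utils. A rough before/after sketch of the import change (illustrative only, not part of this commit; assumes the repo root is on sys.path):

# Before this refactor (old location, now kept only as a deprecation shim):
from diffusers.utils.testing_utils import backend_empty_cache, floats_tensor

# After this refactor (new locations):
from diffusers.utils.torch_utils import backend_empty_cache  # device/determinism helpers stay in src
from tests.testing_utils import floats_tensor                # test-only helpers move out of src
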
@@ -25,6 +25,11 @@ from os.path import abspath, dirname, join
 git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
 sys.path.insert(1, git_repo_path)
+# Add parent directory to path so we can import from tests
+repo_root = abspath(dirname(dirname(__file__)))
+if repo_root not in sys.path:
+    sys.path.insert(0, repo_root)
 # silence FutureWarning warnings in tests since often we can't act on them until
 # they become normal warnings - i.e. the tests still need to test the current functionality
@@ -32,13 +37,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 def pytest_addoption(parser):
-    from diffusers.utils.testing_utils import pytest_addoption_shared
+    from tests.testing_utils import pytest_addoption_shared
     pytest_addoption_shared(parser)
 def pytest_terminal_summary(terminalreporter):
-    from diffusers.utils.testing_utils import pytest_terminal_summary_main
+    from tests.testing_utils import pytest_terminal_summary_main
     make_reports = terminalreporter.config.getoption("--make-reports")
     if make_reports:
...
@@ -24,6 +24,8 @@ import math
 import os
 import random
 import shutil
+# Add repo root to path to import from tests
 from pathlib import Path
 import accelerate
@@ -54,8 +56,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.testing_utils import backend_empty_cache
-from diffusers.utils.torch_utils import is_compiled_module
+from diffusers.utils.torch_utils import backend_empty_cache, is_compiled_module
 if is_wandb_available():
...
@@ -24,12 +24,18 @@ import tempfile
 import torch
 from diffusers import VQModel
-from diffusers.utils.testing_utils import require_timm
+# Add parent directories to path to import from tests
 sys.path.append("..")
+repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if repo_root not in sys.path:
+    sys.path.insert(0, repo_root)
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
+from tests.testing_utils import require_timm  # noqa
 logging.basicConfig(level=logging.DEBUG)
...
@@ -66,7 +66,10 @@ else:
 global_rng = random.Random()
 logger = get_logger(__name__)
+logger.warning(
+    "diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
+    "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
+)
 _required_peft_version = is_peft_available() and version.parse(
     version.parse(importlib.metadata.version("peft")).base_version
 ) > version.parse("0.5")
@@ -801,10 +804,9 @@ def export_to_ply(mesh, output_ply_path: str = None):
                 f.write(format.pack(*vertex))
         if faces is not None:
-            format = struct.Struct("<B3I")
             for tri in faces.tolist():
                 f.write(format.pack(len(tri), *tri))
+            format = struct.Struct("<B3I")
     return output_ply_path
@@ -1144,23 +1146,23 @@ def enable_full_determinism():
     Helper function for reproducible behavior during distributed training. See
     - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
     """
-    # Enable PyTorch deterministic mode. This potentially requires either the environment
-    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
-    # depending on the CUDA version, so we set them both here
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-    torch.use_deterministic_algorithms(True)
-    # Enable CUDNN deterministic mode
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
-    torch.backends.cuda.matmul.allow_tf32 = False
+    from .torch_utils import enable_full_determinism as _enable_full_determinism
+    logger.warning(
+        "enable_full_determinism has been moved to diffusers.utils.torch_utils. "
+        "Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _enable_full_determinism()
 def disable_full_determinism():
-    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
-    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
-    torch.use_deterministic_algorithms(False)
+    from .torch_utils import disable_full_determinism as _disable_full_determinism
+    logger.warning(
+        "disable_full_determinism has been moved to diffusers.utils.torch_utils. "
+        "Importing from diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _disable_full_determinism()
 # Utils for custom and alternative accelerator devices
@@ -1282,43 +1284,85 @@ def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable],
 # These are callables which automatically dispatch the function specific to the accelerator
 def backend_manual_seed(device: str, seed: int):
-    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+    from .torch_utils import backend_manual_seed as _backend_manual_seed
+    logger.warning(
+        "backend_manual_seed has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_manual_seed(device, seed)
 def backend_synchronize(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+    from .torch_utils import backend_synchronize as _backend_synchronize
+    logger.warning(
+        "backend_synchronize has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_synchronize(device)
 def backend_empty_cache(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+    from .torch_utils import backend_empty_cache as _backend_empty_cache
+    logger.warning(
+        "backend_empty_cache has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_empty_cache(device)
 def backend_device_count(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+    from .torch_utils import backend_device_count as _backend_device_count
+    logger.warning(
+        "backend_device_count has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_device_count(device)
 def backend_reset_peak_memory_stats(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+    from .torch_utils import backend_reset_peak_memory_stats as _backend_reset_peak_memory_stats
+    logger.warning(
+        "backend_reset_peak_memory_stats has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_reset_peak_memory_stats(device)
 def backend_reset_max_memory_allocated(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+    from .torch_utils import backend_reset_max_memory_allocated as _backend_reset_max_memory_allocated
+    logger.warning(
+        "backend_reset_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_reset_max_memory_allocated(device)
 def backend_max_memory_allocated(device: str):
-    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+    from .torch_utils import backend_max_memory_allocated as _backend_max_memory_allocated
+    logger.warning(
+        "backend_max_memory_allocated has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_max_memory_allocated(device)
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
-    if not is_torch_available():
-        return False
-    if device not in BACKEND_SUPPORTS_TRAINING:
-        device = "default"
-    return BACKEND_SUPPORTS_TRAINING[device]
+    from .torch_utils import backend_supports_training as _backend_supports_training
+    logger.warning(
+        "backend_supports_training has been moved to diffusers.utils.torch_utils. "
+        "diffusers.utils.testing_utils is deprecated and will be removed in a future version."
+    )
+    return _backend_supports_training(device)
 # Guard for when Torch is not available
...
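
The hunks above apply the same shim pattern to every relocated helper: lazily import the implementation from its new home, emit a deprecation warning, and delegate. A minimal self-contained sketch of that pattern with illustrative names (not the actual diffusers code):

import logging

logger = logging.getLogger(__name__)


def _relocated_backend_manual_seed(device: str, seed: int):
    # Stand-in for the implementation that now lives in the new module.
    print(f"seeding {device} with seed {seed}")


def backend_manual_seed(device: str, seed: int):
    # Old import location kept as a thin wrapper: warn, then delegate.
    logger.warning("backend_manual_seed has moved; the old import path is deprecated.")
    return _relocated_backend_manual_seed(device, seed)
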
@@ -16,7 +16,8 @@ PyTorch utilities: Utilities related to PyTorch
 """
 import functools
-from typing import List, Optional, Tuple, Union
+import os
+from typing import Callable, Dict, List, Optional, Tuple, Union
 from . import logging
 from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
@@ -26,6 +27,56 @@ if is_torch_available():
     import torch
     from torch.fft import fftn, fftshift, ifftn, ifftshift
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": torch.mps.empty_cache,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "mps": torch.mps.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+        "cpu": 0,
+        "mps": 0,
+        "default": 0,
+    }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "xpu": getattr(torch.xpu, "synchronize", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 try:
@@ -36,6 +87,62 @@ except (ImportError, ModuleNotFoundError):
         return cls
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+    fn = dispatch_table[device]
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level
+    if not callable(fn):
+        return fn
+    return fn(*args, **kwargs)
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+    return BACKEND_SUPPORTS_TRAINING[device]
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
@@ -197,3 +304,31 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+def enable_full_determinism():
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    """
+    # Enable PyTorch deterministic mode. This potentially requires either the environment
+    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+    # depending on the CUDA version, so we set them both here
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+    torch.use_deterministic_algorithms(True)
+    # Enable CUDNN deterministic mode
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cuda.matmul.allow_tf32 = False
+def disable_full_determinism():
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+    torch.use_deterministic_algorithms(False)
+if is_torch_available():
+    torch_device = get_device()
...
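
For reference, a rough usage sketch of the device-agnostic helpers added to diffusers.utils.torch_utils above (illustrative only; assumes a working PyTorch install, with behaviour following the dispatch tables shown in the diff):

import torch

from diffusers.utils.torch_utils import (
    backend_device_count,
    backend_empty_cache,
    backend_manual_seed,
    backend_supports_training,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

backend_manual_seed(device, 0)             # torch.cuda.manual_seed on CUDA, torch.manual_seed on CPU
print(backend_device_count(device))        # torch.cuda.device_count() on CUDA, 0 on CPU
backend_empty_cache(device)                # no-op on CPU: the table entry is None and is returned as-is
print(backend_supports_training(device))   # True for cuda/xpu/cpu, False for mps
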
@@ -35,13 +35,13 @@ def pytest_configure(config):
 def pytest_addoption(parser):
-    from diffusers.utils.testing_utils import pytest_addoption_shared
+    from .testing_utils import pytest_addoption_shared
     pytest_addoption_shared(parser)
 def pytest_terminal_summary(terminalreporter):
-    from diffusers.utils.testing_utils import pytest_terminal_summary_main
+    from .testing_utils import pytest_terminal_summary_main
     make_reports = terminalreporter.config.getoption("--make-reports")
     if make_reports:
...
@@ -24,7 +24,8 @@ from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
 from diffusers.utils.import_utils import compare_versions
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     backend_empty_cache,
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
...
@@ -20,7 +20,8 @@ import torch
 from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.training_utils import free_memory
 from diffusers.utils.logging import get_logger
-from diffusers.utils.testing_utils import CaptureLogger, torch_device
+from ..testing_utils import CaptureLogger, torch_device
 logger = get_logger(__name__)  # pylint: disable=invalid-name
...
@@ -23,7 +23,8 @@ from diffusers import (
     AuraFlowTransformer2DModel,
     FlowMatchEulerDiscreteScheduler,
 )
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     floats_tensor,
     is_peft_available,
     require_peft_backend,
@@ -35,7 +36,7 @@ if is_peft_available():
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...
@@ -26,7 +26,8 @@ from diffusers import (
     CogVideoXPipeline,
     CogVideoXTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     floats_tensor,
     require_peft_backend,
     require_torch_accelerator,
@@ -35,7 +36,7 @@ from diffusers.utils.testing_utils import (
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...
@@ -22,7 +22,8 @@ from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel
 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     floats_tensor,
     require_peft_backend,
     require_torch_accelerator,
@@ -33,7 +34,7 @@ from diffusers.utils.testing_utils import (
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class TokenizerWrapper:
...
@@ -28,7 +28,8 @@ from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderM
 from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
 from diffusers.utils import load_image, logging
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     CaptureLogger,
     backend_empty_cache,
     floats_tensor,
@@ -48,7 +49,7 @@ if is_peft_available():
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 @require_peft_backend
...
@@ -26,7 +26,8 @@ from diffusers import (
     HunyuanVideoPipeline,
     HunyuanVideoTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     Expectations,
     backend_empty_cache,
     floats_tensor,
@@ -42,7 +43,7 @@ from diffusers.utils.testing_utils import (
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...
@@ -24,12 +24,13 @@ from diffusers import (
     LTXPipeline,
     LTXVideoTransformer3DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from ..testing_utils import floats_tensor, require_peft_backend
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...
@@ -26,12 +26,13 @@ from diffusers import (
     Lumina2Pipeline,
     Lumina2Transformer2DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
+from ..testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 @require_peft_backend
...
@@ -19,7 +19,8 @@ import torch
 from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
-from diffusers.utils.testing_utils import (
+from ..testing_utils import (
     floats_tensor,
     require_peft_backend,
     skip_mps,
@@ -28,7 +29,7 @@ from diffusers.utils.testing_utils import (
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...
@@ -24,12 +24,13 @@ from diffusers import (
     QwenImagePipeline,
     QwenImageTransformer2DModel,
 )
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from ..testing_utils import floats_tensor, require_peft_backend
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...
@@ -19,12 +19,13 @@ import torch
 from transformers import Gemma2Model, GemmaTokenizer
 from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+from ..testing_utils import floats_tensor, require_peft_backend
 sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
...