Unverified Commit 7e29b747 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Check k-diffusion version is at least 0.0.12 (#2022)

* Check k-diffusion version is at least 0.0.12

* make style
parent a43bdd01
...@@ -91,7 +91,7 @@ _deps = [ ...@@ -91,7 +91,7 @@ _deps = [
"isort>=5.5.4", "isort>=5.5.4",
"jax>=0.2.8,!=0.3.2", "jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65", "jaxlib>=0.1.65",
"k-diffusion", "k-diffusion>=0.0.12",
"librosa", "librosa",
"modelcards>=0.1.4", "modelcards>=0.1.4",
"numpy", "numpy",
......
...@@ -6,6 +6,7 @@ from .utils import ( ...@@ -6,6 +6,7 @@ from .utils import (
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version,
is_librosa_available, is_librosa_available,
is_onnx_available, is_onnx_available,
is_scipy_available, is_scipy_available,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
deps = { deps = {
"Pillow": "Pillow", "Pillow": "Pillow",
"accelerate": "accelerate>=0.11.0", "accelerate": "accelerate>=0.11.0",
"black": "black==22.8", "black": "black==22.12",
"datasets": "datasets", "datasets": "datasets",
"filelock": "filelock", "filelock": "filelock",
"flake8": "flake8>=3.8.3", "flake8": "flake8>=3.8.3",
...@@ -15,7 +15,7 @@ deps = { ...@@ -15,7 +15,7 @@ deps = {
"isort": "isort>=5.5.4", "isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2", "jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65", "jaxlib": "jaxlib>=0.1.65",
"k-diffusion": "k-diffusion", "k-diffusion": "k-diffusion>=0.0.12",
"librosa": "librosa", "librosa": "librosa",
"modelcards": "modelcards>=0.1.4", "modelcards": "modelcards>=0.1.4",
"numpy": "numpy", "numpy": "numpy",
......
...@@ -11,6 +11,7 @@ from ...utils import ( ...@@ -11,6 +11,7 @@ from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
is_flax_available, is_flax_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version,
is_onnx_available, is_onnx_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
...@@ -64,7 +65,7 @@ else: ...@@ -64,7 +65,7 @@ else:
try: try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): if not (is_torch_available() and is_transformers_available() and is_k_diffusion_version(">=", "0.0.12")):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
......
...@@ -47,6 +47,7 @@ from .import_utils import ( ...@@ -47,6 +47,7 @@ from .import_utils import (
is_flax_available, is_flax_available,
is_inflect_available, is_inflect_available,
is_k_diffusion_available, is_k_diffusion_available,
is_k_diffusion_version,
is_librosa_available, is_librosa_available,
is_modelcards_available, is_modelcards_available,
is_onnx_available, is_onnx_available,
......
...@@ -427,12 +427,26 @@ def is_transformers_version(operation: str, version: str): ...@@ -427,12 +427,26 @@ def is_transformers_version(operation: str, version: str):
operation (`str`): operation (`str`):
A string representation of an operator, such as `">"` or `"<="` A string representation of an operator, such as `">"` or `"<="`
version (`str`): version (`str`):
A string version of PyTorch A version string
""" """
if not _transformers_available: if not _transformers_available:
return False return False
return compare_versions(parse(_transformers_version), operation, version) return compare_versions(parse(_transformers_version), operation, version)
def is_k_diffusion_version(operation: str, version: str):
"""
Args:
Compares the current k-diffusion version to a given reference with an operation.
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _k_diffusion_available:
return False
return compare_versions(parse(_k_diffusion_version), operation, version)
class OptionalDependencyNotAvailable(BaseException): class OptionalDependencyNotAvailable(BaseException):
"""An error indicating that an optional dependency of Diffusers was not found in the environment.""" """An error indicating that an optional dependency of Diffusers was not found in the environment."""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment