Unverified Commit 1c151283 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[versions] handle version requirement ranges (#11110)

* handle version requirement ranges

* add mixed requirement test

* cleanup
parent 7442801d
......@@ -40,6 +40,17 @@ ops = {
}
def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
if got_ver is None:
raise ValueError("got_ver is None")
if want_ver is None:
raise ValueError("want_ver is None")
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
raise ImportError(
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
)
def require_version(requirement: str, hint: Optional[str] = None) -> None:
"""
Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
......@@ -51,33 +62,36 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
"""
# note: while pkg_resources.require_version(requirement) is a much simpler way to do it, it
# fails if some of the dependencies of the dependencies are not matching, which is not necessarily
# bad, hence the more complicated check - which also should be faster, since it doesn't check
# dependencies of dependencies.
hint = f"\n{hint}" if hint is not None else ""
# non-versioned check
if re.match(r"^[\w_\-\d]+$", requirement):
pkg, op, want_ver = requirement, None, None
else:
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2})(.+)", requirement)
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
if not match:
raise ValueError(
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
)
pkg, op, want_ver = match[0]
if op not in ops:
raise ValueError(f"need one of {list(ops.keys())}, but got {op}")
pkg, want_full = match[0]
want_range = want_full.split(",") # there could be multiple requirements
wanted = {}
for w in want_range:
match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
if not match:
raise ValueError(
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
)
op, want_ver = match[0]
wanted[op] = want_ver
if op not in ops:
raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
# special case
if pkg == "python":
got_ver = ".".join([str(x) for x in sys.version_info[:3]])
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
raise ImportError(
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}."
)
for op, want_ver in wanted.items():
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
return
# check if any version is installed
......@@ -88,11 +102,10 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
f"The '{requirement}' distribution was not found and is required by this application. {hint}"
)
# check that the right version is installed if version number was provided
if want_ver is not None and not ops[op](version.parse(got_ver), version.parse(want_ver)):
raise ImportError(
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
)
# check that the right version is installed if version number or a range was provided
if want_ver is not None:
for op, want_ver in wanted.items():
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
def require_version_core(requirement):
......
......@@ -14,8 +14,6 @@
import sys
import numpy
from transformers.testing_utils import TestCasePlus
from transformers.utils.versions import (
importlib_metadata,
......@@ -25,7 +23,7 @@ from transformers.utils.versions import (
)
numpy_ver = numpy.__version__
numpy_ver = importlib_metadata.version("numpy")
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
......@@ -54,6 +52,9 @@ class DependencyVersionCheckTest(TestCasePlus):
# gt
require_version_core("numpy>1.0.0")
# mix
require_version_core("numpy>1.0.0,<1000")
# requirement w/o version
require_version_core("numpy")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment