Unverified Commit f62407f7 authored by bhack's avatar bhack Committed by GitHub
Browse files

Extend import utils to cover "editable" torch versions (#29000)



* Extend import utils to cover "editable" torch versions

* Re-add type hint

* Remove whitespaces

* Double quote strings

* Update comment
Co-authored-by: default avatarYih-Dar <2521628+ydshieh@users.noreply.github.com>

* Restore package_exists

* Revert "Restore package_exists"

This reverts commit 66fd2cd5c33d1b9a26a8f3e8adef2e6ec1214868.

---------
Co-authored-by: default avatarYih-Dar <2521628+ydshieh@users.noreply.github.com>
parent 56b64bf1
......@@ -39,16 +39,32 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better.
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
# Check if the package spec exists and grab its version to avoid importing a local directory
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
# Primary method to get the package version
package_version = importlib.metadata.version(pkg_name)
package_exists = True
except importlib.metadata.PackageNotFoundError:
# Fallback method: Only for "torch" and versions containing "dev"
if pkg_name == "torch":
try:
package = importlib.import_module(pkg_name)
temp_version = getattr(package, "__version__", "N/A")
# Check if the version contains "dev"
if "dev" in temp_version:
package_version = temp_version
package_exists = True
else:
package_exists = False
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
logger.debug(f"Detected {pkg_name} version {package_version}")
logger.debug(f"Detected {pkg_name} version: {package_version}")
if return_version:
return package_exists, package_version
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment