Unverified Commit d705f7ff authored by Alp Dener, committed by GitHub
Browse files

[PyTorch] Replaced deprecated `pkg_resources` with `packaging` (#860)



replaced deprecated pkg_resources with packaging
Signed-off-by: Alp Dener <adener@nvidia.com>
parent f0311a18
...@@ -246,6 +246,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -246,6 +246,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
install_reqs: List[str] = [ install_reqs: List[str] = [
"pydantic", "pydantic",
"importlib-metadata>=1.0; python_version<'3.8'", "importlib-metadata>=1.0; python_version<'3.8'",
"packaging",
] ]
test_reqs: List[str] = ["pytest"] test_reqs: List[str] = ["pytest"]
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
"""Attention.""" """Attention."""
import collections import collections
from contextlib import nullcontext from contextlib import nullcontext
from importlib.metadata import version from importlib.metadata import version as get_pkg_version
import math import math
import os import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings import warnings
import numpy as np import numpy as np
from pkg_resources import packaging from packaging.version import Version as PkgVersion
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -67,13 +67,13 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo ...@@ -67,13 +67,13 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6") _flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = packaging.version.Version("2.5.8") _flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1") _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
if _flash_attn_version >= _flash_attn_version_required: if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment