"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "82bc797f17eee5f830edcd058e79390a0c5acff6"
Unverified Commit d705f7ff authored by Alp Dener's avatar Alp Dener Committed by GitHub
Browse files

[PyTorch] Replaced deprecated `pkg_resources` with `packaging` (#860)



replaced deprecated pkg_resources with packaging
Signed-off-by: default avatarAlp Dener <adener@nvidia.com>
parent f0311a18
......@@ -246,6 +246,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
install_reqs: List[str] = [
"pydantic",
"importlib-metadata>=1.0; python_version<'3.8'",
"packaging",
]
test_reqs: List[str] = ["pytest"]
......
......@@ -5,14 +5,14 @@
"""Attention."""
import collections
from contextlib import nullcontext
from importlib.metadata import version
from importlib.metadata import version as get_pkg_version
import math
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import warnings
import numpy as np
from pkg_resources import packaging
from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
......@@ -67,13 +67,13 @@ from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")
_flash_attn_max_version = packaging.version.Version("2.5.8")
_flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= packaging.version.Version("2.4.1")
_flash_attn_version = PkgVersion(get_pkg_version("flash-attn"))
_flash_attn_version_required = PkgVersion("2.0.6")
_flash_attn_max_version = PkgVersion("2.5.8")
_flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1")
_flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3")
_flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4")
_flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1")
if _flash_attn_version >= _flash_attn_version_required:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_forward_func # pylint: disable=no-name-in-module
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment