Unverified Commit 36e0ac56 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Remove use of distutils (#186)



Remove distutils
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1bc86400
......@@ -10,10 +10,10 @@ import io
import re
import copy
import tempfile
from packaging.version import Version
from setuptools import setup, find_packages, Extension
from setuptools.command.build_ext import build_ext
from distutils.version import LooseVersion
from distutils.file_util import copy_file
from shutil import copyfile
path = os.path.dirname(os.path.realpath(__file__))
......@@ -172,7 +172,7 @@ class PyTorchBuilder(FrameworkBuilderBase):
@staticmethod
def install_requires():
return ["flash-attn>=1.0.2",]
return ["flash-attn>=1.0.2", "packaging"]
class TensorFlowBuilder(FrameworkBuilderBase):
......@@ -244,13 +244,13 @@ def get_cmake_bin():
try:
out = subprocess.check_output([cmake_bin, "--version"])
except OSError:
cmake_installed_version = LooseVersion("0.0")
cmake_installed_version = Version("0.0")
else:
cmake_installed_version = LooseVersion(
cmake_installed_version = Version(
re.search(r"version\s*([\d.]+)", out.decode()).group(1)
)
if cmake_installed_version < LooseVersion("3.18.0"):
if cmake_installed_version < Version("3.18.0"):
print(
"Could not find a recent CMake to build Transformer Engine. "
"Attempting to install CMake 3.18 to a temporary location via pip.",
......@@ -399,12 +399,7 @@ class TEBuildExtension(build_ext, object):
# Always copy, even if source is older than destination, to ensure
# that the right extensions for the current Python/platform are
# used.
copy_file(
src_filename,
dest_filename,
verbose=self.verbose,
dry_run=self.dry_run,
)
copyfile(src_filename, dest_filename)
def get_outputs(self):
return self.all_outputs
......
......@@ -7,9 +7,9 @@ import os
import math
import warnings
from importlib.metadata import version
from distutils.version import LooseVersion
from contextlib import nullcontext
from typing import Any, Callable, Optional, Tuple, Union
from packaging.version import Version
import torch
......@@ -45,8 +45,8 @@ from transformer_engine.pytorch.distributed import (
)
from transformer_engine.pytorch.export import is_in_onnx_export_mode
_flash_attn_version = LooseVersion(version("flash-attn"))
_flash_attn_version_required = LooseVersion("1.0.2")
_flash_attn_version = Version(version("flash-attn"))
_flash_attn_version_required = Version("1.0.2")
warnings.filterwarnings("module", category=DeprecationWarning, module="transformer")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment