Commit d07683f3 authored by zhanggezhong's avatar zhanggezhong
Browse files

Update setup.py

parent f8f6f259
......@@ -23,7 +23,9 @@ import shutil
import http.client
import urllib.request
import urllib.error
import importlib
from pathlib import Path
from packaging import version
from packaging.version import parse
import torch.version
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
......@@ -328,7 +330,38 @@ class CMakeBuild(BuildExtension):
["cmake", "--build", ".", "--verbose", *build_args], cwd=build_temp, check=True
)
USE_FASTPT_CUDA = os.getenv('USE_FASTPT_CUDA', 'False').lower() == 'true'
def check_fastpt_version():
try:
# Try to import the fastpt module
fastpt = importlib.import_module('fastpt')
# Get version number
fastpt_version = getattr(fastpt, '__version__', None)
if fastpt_version is None:
raise ImportError("fastpt module doesn't have __version__ attribute, cannot determine version")
print(f"Detected fastpt installation, version: {fastpt_version}")
# Compare version numbers
if version.parse(fastpt_version) >= version.parse('2.0.2'):
print("fastpt version ≥ 2.0.2")
return True
else:
print(f"fastpt version {fastpt_version} < 2.0.2")
return False
except ImportError as e:
print(f"Error: fastpt not installed or import failed - {str(e)}")
raise
try:
if check_fastpt_version():
USE_FASTPT_CUDA = os.getenv('USE_FASTPT_CUDA', '0') == '1'
else:
USE_FASTPT_CUDA = os.getenv('USE_FASTPT_CUDA', 'False').lower() == 'true'
except Exception as e:
print(f"Program terminated: {str(e)}")
if CUDA_HOME is not None:
extra_nvcc_flags = [
'-O3',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment