import os
import re
import subprocess
import sys
import sysconfig
from typing import Optional

import setuptools
import torch
from setuptools import Distribution


# Absolute directory that contains this setup.py.
pwd = os.path.dirname(os.path.abspath(__file__))

# The git commit hash is appended to the version only when ADD_GIT_VERSION=1.
add_git_version = int(os.environ.get('ADD_GIT_VERSION', '0')) == 1


def parse_custom_args(argv):
    """Split this project's own flags out of a command line.

    Args:
        argv: Full argument list (normally ``sys.argv``).

    Returns:
        A ``(shmem, build_shca, remaining)`` tuple:

        * ``shmem`` -- value of the last ``--shmem=<value>`` flag, with
          ``"rocm"`` mapped to ``"a"`` and ``"nv"`` mapped to ``"b"``; any
          other value is kept verbatim; ``None`` when the flag is absent.
        * ``build_shca`` -- True when ``--build_shca`` was present.
        * ``remaining`` -- every other argument, in its original order.
    """
    aliases = {"rocm": "a", "nv": "b"}
    shmem_value = None
    shca_requested = False
    remaining = []
    for arg in argv:
        if arg.startswith("--shmem="):
            raw_value = arg.split("=", 1)[1]
            shmem_value = aliases.get(raw_value, raw_value)
        elif arg == "--build_shca":
            shca_requested = True
        else:
            remaining.append(arg)
    return shmem_value, shca_requested, remaining


# Strip the custom flags from sys.argv before setuptools parses it, since
# setuptools would reject options it does not know.
shmem, build_shca, other = parse_custom_args(sys.argv)
sys.argv = other


def get_version_add(sha: Optional[str] = None) -> str:
    """Generate ``deep_ep/version.py`` and return the local version suffix.

    The suffix starts as ``das.opt1``. When the module-level
    ``add_git_version`` flag is set and ``sha`` is not ``'Unknown'``, the
    first seven characters of the commit hash and the module-level ``shmem``
    tag (if any) are appended. When ``ROCM_PATH`` is set, ``.dtk<version>``
    is appended, taken from the first line of ``$ROCM_PATH/.info/rocm_version``
    with the dots removed.

    Side effects:
        * Adds this directory to the *global* git ``safe.directory`` list
          (best effort; failures are ignored).
        * Overwrites ``deep_ep/version.py`` next to this file.

    Args:
        sha: Commit hash to embed. ``None`` means "ask git for HEAD";
            the literal ``'Unknown'`` means "do not embed a hash".

    Returns:
        The suffix placed after ``1.1.0+`` in ``__hcu_version__``.

    Raises:
        RuntimeError: If the installed torch version cannot be parsed or is
            older than 2.4 (no version tag is defined for those releases).
        subprocess.CalledProcessError: If ``git rev-parse HEAD`` fails.
        OSError: If the ROCm version file or ``version.py`` cannot be
            read or written.
    """
    deepep_root = os.path.dirname(os.path.abspath(__file__))
    add_version_path = os.path.join(deepep_root, "deep_ep", "version.py")

    # Best effort, as before: the exit status was never checked. An argument
    # list (no shell) keeps paths with spaces or shell metacharacters intact.
    try:
        subprocess.run(
            ["git", "config", "--global", "--add", "safe.directory", deepep_root],
            check=False,
        )
    except OSError:
        # git is not installed; the shell form used to ignore this silently.
        pass

    # Compare the torch version numerically. Comparing the string parts
    # ranked "10" below "4", and unpacking exactly three dot-separated parts
    # failed on versions that carry a dotted local suffix.
    matched = re.match(r"(\d+)\.(\d+)", torch.__version__)
    if matched is None:
        raise RuntimeError(f"Cannot parse torch version {torch.__version__!r}")
    torch_major_minor = (int(matched.group(1)), int(matched.group(2)))
    if torch_major_minor < (2, 4):
        raise RuntimeError(
            f"No deep_ep version tag is defined for torch {torch.__version__}; "
            "torch >= 2.4 is required"
        )

    version = 'das.opt1'
    if add_git_version and sha != 'Unknown':
        if sha is None:
            sha = subprocess.check_output(
                ['git', 'rev-parse', 'HEAD'], cwd=deepep_root
            ).decode('ascii').strip()
        # shmem is None when --shmem was not passed on the command line.
        version += '.' + sha[:7] + (shmem or '')

    rocm_path = os.getenv("ROCM_PATH")
    if rocm_path:
        rocm_version_path = os.path.join(rocm_path, '.info', "rocm_version")
        with open(rocm_version_path, 'r', encoding='utf-8') as file:
            first_line = file.readline()
        # Strip the line ending: a newline would otherwise land inside the
        # single-quoted string literal written to version.py below.
        version += ".dtk" + first_line.strip().replace(".", "")

    # Doubled braces are emitted as literal braces in the generated file;
    # only {version} is substituted here.
    new_version_content = f"""
try:
    __version__ = "1.1.0"
    __version_tuple__ = (1, 1, 0)
    __hcu_version__ = f'1.1.0+{version}'

    from deep_ep.version import __version__, __version_tuple__, __hcu_version__
except Exception as e:
    import warnings
    warnings.warn(f"Failed to read commit hash:\\n{{e}}",
                  RuntimeWarning, stacklevel=2)
    __version__ = "dev"
    __version_tuple__ = (0, 0, __version__)

def _prev_minor_version_was(version_str):
    '''Check whether a given version matches the previous minor version.

    Return True if version_str matches the previous minor version.

    For example - return True if the current version is 0.7.4 and the
    supplied version_str is '0.6'.

    Used for --show-hidden-metrics-for-version.
    '''
    # Match anything if this is a dev tree
    if __version_tuple__[0:2] == (0, 0):
        return True

    # Note - this won't do the right thing when we release 1.0!
    # assert __version_tuple__[0] == 0
    assert isinstance(__version_tuple__[1], int)
    return version_str == f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"

def _prev_minor_version():
    '''For the purpose of testing, return a previous minor version number.'''
    # In dev tree, this will return "0.-1", but that will work fine"
    assert isinstance(__version_tuple__[1], int)
    return f"{{__version_tuple__[0]}}.{{__version_tuple__[1] - 1}}"
"""

    with open(add_version_path, encoding="utf-8", mode="w") as file:
        file.write(new_version_content)

    return version

def get_version():
    """Regenerate ``deep_ep/version.py`` and return its ``__hcu_version__``.

    The generated file is executed in a dedicated namespace dictionary.
    Reading the result back through ``locals()`` after a bare ``exec()`` does
    not work on Python 3.13+, where names assigned by the executed code are
    no longer visible in the calling function's locals.

    Returns:
        The ``__hcu_version__`` string defined by the generated file.

    Raises:
        KeyError: If the generated file does not define ``__hcu_version__``.
    """
    get_version_add()
    # Resolve the path from this file's location, matching where
    # get_version_add() writes it, instead of the current working directory.
    version_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'deep_ep', 'version.py'
    )
    with open(version_file, encoding='utf-8') as f:
        source = f.read()
    # NOTE: exec is acceptable here only because the file was just generated
    # by this script; never point this at externally supplied content.
    namespace = {}
    exec(compile(source, version_file, 'exec'), namespace)
    return namespace['__hcu_version__']

def get_deepep_version() -> str:
    """Return the version string for the ``deep_ep`` distribution.

    Thin wrapper around ``get_version()``; calling it regenerates
    ``deep_ep/version.py`` as a side effect.
    """
    return get_version()


class BinaryDistribution(Distribution):
    """Distribution that always reports having extension modules.

    The package ships its compiled ``deep_ep_cpp`` module through
    ``package_data`` rather than ``ext_modules``.
    NOTE(review): presumably this override exists so the built wheel is
    tagged as platform-specific -- confirm against the setuptools docs.
    """

    def has_ext_modules(self):
        """Return True unconditionally."""
        return True


# '.shca' is appended to the package version when --build_shca was passed.
version_suffix = '.shca' if build_shca else ''

if __name__ == '__main__':
    setuptools.setup(
        name='deep_ep',
        version=get_deepep_version() + version_suffix,
        packages=setuptools.find_packages(include=['deep_ep']),
        include_package_data=True,
        # Ship the prebuilt extension whose file name carries this
        # interpreter's extension-module suffix (EXT_SUFFIX).
        package_data={"deep_ep": [f"deep_ep_cpp{sysconfig.get_config_var('EXT_SUFFIX')}"]},
        zip_safe=False,
        distclass=BinaryDistribution,
    )