# setup.py -- build script for the "dropout_layer_norm" extension (ROCm/HIP).
import os
import subprocess
from packaging.version import parse, Version
from typing import Optional
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from get_version import get_version
# Absolute path of the directory holding this script; used as an include
# directory when the extension is compiled.
this_dir = os.path.split(os.path.abspath(__file__))[0]
def _find_rocm_home() -> Optional[str]:
    """Locate the ROCm installation root.

    Lookup order:

    1. The ``ROCM_HOME`` environment variable, then ``ROCM_PATH``.
    2. The ``hipcc`` executable on ``PATH``: symlinks are resolved and the
       root is taken two directories above the binary
       (``<root>/bin/hipcc``), or three when the binary lives under a
       ``hip`` directory (``<root>/hip/bin/hipcc``).
    3. ``/opt/rocm``, if that path exists.

    Returns:
        The ROCm root directory, or ``None`` when none of the lookups
        succeeded.
    """
    # Local import so this function does not depend on a new file-level import.
    import shutil

    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if rocm_home is None:
        # A shell pipeline does not raise when hipcc is missing (it just
        # produces empty output), so probe PATH directly and branch on the
        # result instead of relying on an exception to reach the fallback.
        hipcc_path = shutil.which('hipcc')
        if hipcc_path is not None:
            resolved = os.path.realpath(hipcc_path)
            rocm_home = os.path.dirname(os.path.dirname(resolved))
            if os.path.basename(rocm_home) == 'hip':
                rocm_home = os.path.dirname(rocm_home)
        elif os.path.exists('/opt/rocm'):
            rocm_home = '/opt/rocm'
    if rocm_home and torch.version.hip is None:
        print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
    return rocm_home

# Resolve the ROCm root once, at import time.
ROCM_HOME = _find_rocm_home()

# One --offload-arch flag per GPU architecture to generate device code for.
cc_flag = [
    "--offload-arch=gfx906",
    "--offload-arch=gfx926",
    "--offload-arch=gfx928",
    "--offload-arch=gfx936",
]
# Single-architecture alternative:
# cc_flag = ["--offload-arch=gfx906"]

# Numeric suffixes of the per-size kernel source files (presumably the
# supported hidden sizes -- confirm against the kernel sources).
_KERNEL_SIZE_SUFFIXES = (
    256, 512, 768, 1024, 1280, 1536, 2048,
    2560, 3072, 4096, 5120, 6144, 7168, 8192,
)

# "ln_api.cu" first, then for each family ("ln", "ln_parallel") a
# forward/backward pair of files per size suffix, in ascending size order.
_KERNEL_SOURCES = ["ln_api.cu"] + [
    f"{family}_{direction}_{suffix}.cu"
    for family in ("ln", "ln_parallel")
    for suffix in _KERNEL_SIZE_SUFFIXES
    for direction in ("fwd", "bwd")
]

ext_modules = [
    CUDAExtension(
        name="dropout_layer_norm",
        sources=_KERNEL_SOURCES,
        extra_compile_args={
            # -O3: optimize; -w: silence compiler warnings.
            "cxx": ["-O3", "-w"],
            # The -U flags undefine the macros that disable HIP half
            # operators/conversions; the architecture flags are appended last.
            "nvcc": [
                "-O3",
                "-w",
                "-U__HIP_NO_HALF_OPERATORS__",
                "-U__HIP_NO_HALF_CONVERSIONS__",
                *cc_flag,
            ],
        },
        include_dirs=[this_dir],
    )
]

def _get_pytorch_version():
    """Return the torch version to pin in ``install_requires``.

    A non-empty ``PYTORCH_VERSION`` environment variable takes precedence.
    When it is unset or empty, the version of the currently imported torch
    is used; an empty override would otherwise yield the invalid
    requirement string ``torch==``.

    Returns:
        The version string from the environment, or ``torch.__version__``.
    """
    return os.environ.get("PYTORCH_VERSION") or torch.__version__

# Package definition: one compiled extension, pinned to a specific torch
# version (see _get_pytorch_version).
setup(
    name="dropout_layer_norm",
    version=get_version(ROCM_HOME),
    description="Fused dropout + add + layer norm",
    ext_modules=ext_modules,
    # Install torch's build_ext command only when there is an extension to build.
    cmdclass=dict(build_ext=BuildExtension) if ext_modules else dict(),
    install_requires=["torch=={}".format(_get_pytorch_version())],
)