setup.py 2.88 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
import subprocess

from config import kernels, sources, target
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

target = target.lower()

# Package metadata
PACKAGE_NAME = "st_attn"
VERSION = "0.0.4"
AUTHOR = "Hao AI Lab"
DESCRIPTION = "Sliding Tile Atteniton Kernel Used in FastVideo"
URL = "https://github.com/hao-ai-lab/FastVideo/tree/main/csrc/sliding_tile_attention"

# Set environment variables
tk_root = os.getenv('THUNDERKITTENS_ROOT', os.path.abspath(os.path.join(os.getcwd(), 'tk/')))
python_include = subprocess.check_output(['python', '-c',
                                          "import sysconfig; print(sysconfig.get_path('include'))"]).decode().strip()
torch_include = subprocess.check_output([
    'python', '-c',
    "import torch; from torch.utils.cpp_extension import include_paths; print(' '.join(['-I' + p for p in include_paths()]))"
]).decode().strip()
print('st_attn root:', tk_root)
print('Python include:', python_include)
print('Torch include directories:', torch_include)

# CUDA flags
cuda_flags = [
    '-DNDEBUG', '-Xcompiler=-Wno-psabi', '-Xcompiler=-fno-strict-aliasing', '--expt-extended-lambda',
    '--expt-relaxed-constexpr', '-forward-unknown-to-host-compiler', '--use_fast_math', '-std=c++20', '-O3',
    '-Xnvlink=--verbose', '-Xptxas=--verbose', '-Xptxas=--warn-on-spills', f'-I{tk_root}/include',
    f'-I{tk_root}/prototype', f'-I{python_include}', '-DTORCH_COMPILE'
] + torch_include.split()
cpp_flags = ['-std=c++20', '-O3']

if target == 'h100':
    cuda_flags.append('-DKITTENS_HOPPER')
    cuda_flags.append('-arch=sm_90a')
else:
    raise ValueError(f'Target {target} not supported')

source_files = ['st_attn.cpp']
for k in kernels:
    if target not in sources[k]['source_files']:
        raise KeyError(f'Target {target} not found in source files for kernel {k}')
    if isinstance(sources[k]['source_files'][target], list):
        source_files.extend(sources[k]['source_files'][target])
    else:
        source_files.append(sources[k]['source_files'][target])
    cpp_flags.append(f'-DTK_COMPILE_{k.replace(" ", "_").upper()}')

setup(name=PACKAGE_NAME,
      version=VERSION,
      author=AUTHOR,
      description=DESCRIPTION,
      url=URL,
      packages=find_packages(),
      ext_modules=[
          CUDAExtension('st_attn_cuda',
                        sources=source_files,
                        extra_compile_args={
                            'cxx': cpp_flags,
                            'nvcc': cuda_flags
                        },
                        libraries=['cuda'])
      ],
      cmdclass={'build_ext': BuildExtension},
      classifiers=[
          "Programming Language :: Python :: 3",
          "Environment :: GPU :: NVIDIA CUDA :: 12",
          "License :: OSI Approved :: Apache Software License",
      ],
      python_requires='>=3.10',
      install_requires=["torch>=2.5.0"])