# setup.py
# Build script for the ``custom_kernels`` package: compiles the two fused
# attention sources listed below as extensions via torch's cpp_extension helpers.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import torch

# Compiler flags shared by every extension declared below.
extra_compile_args = ["-std=c++17"]
if not torch.version.hip:
    # Only added when ``torch.version.hip`` is falsy.
    # NOTE(review): presumably the -arch flag is nvcc-specific and must be
    # skipped on ROCm/HIP builds of torch; compute_80 looks like a fixed
    # target architecture — confirm it matches the intended GPUs.
    extra_compile_args.append("-arch=compute_80")

setup(
    name="custom_kernels",
    ext_modules=[
        CUDAExtension(
            name="custom_kernels.fused_bloom_attention_cuda",
            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
            extra_compile_args=extra_compile_args,
        ),
        CUDAExtension(
            name="custom_kernels.fused_attention_cuda",
            sources=["custom_kernels/fused_attention_cuda.cu"],
            extra_compile_args=extra_compile_args,
        ),
    ],
    # Use torch's build_ext command so the extensions above are compiled
    # with the toolchain handling provided by torch.utils.cpp_extension.
    cmdclass={"build_ext": BuildExtension},
)