"""Build script for the ``custom_kernels`` CUDA/HIP extension modules."""

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

import torch

# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
# NOTE(review): "--gpu-max-threads-per-block" looks like a HIP/clang option;
# confirm that plain nvcc accepts it on CUDA builds.
NVCC_FLAGS = ["-O2", "-std=c++17", "--gpu-max-threads-per-block=1024"]

# Compile with the same libstdc++ ABI setting that torch reports for itself,
# for both the host compiler and the device compiler.
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

# Per-compiler flag lists, keyed the way CUDAExtension expects.
extra_compile_args = {
    "cxx": CXX_FLAGS,
    "nvcc": NVCC_FLAGS,
}

if not torch.version.hip:
    # `extra_compile_args` is a dict, so the flag has to be appended to one of
    # its lists; `-arch` is an nvcc option, hence the "nvcc" entry. Only added
    # when torch.version.hip is falsy (i.e. not a HIP build of torch).
    extra_compile_args["nvcc"].append("-arch=compute_80")

setup(
    name="custom_kernels",
    ext_modules=[
        CUDAExtension(
            name="custom_kernels.fused_bloom_attention_cuda",
            sources=["custom_kernels/fused_bloom_attention_cuda.cu"],
            extra_compile_args=extra_compile_args,
        ),
        CUDAExtension(
            name="custom_kernels.fused_attention_cuda",
            sources=["custom_kernels/fused_attention_cuda.cu"],
            extra_compile_args=extra_compile_args,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)