from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from setuptools import find_packages, setup

def get_version():
    """Return the package version string used by ``setup()``."""
    version_string = "0.1"
    return version_string

def get_extensions():
    """Build the list of native extension modules for ``setup()``.

    Collects every ``.cu`` and ``.cc`` file in ``./glmop`` (relative to the
    current working directory, so run the build from the directory holding
    this file) and wraps them in a single ``CUDAExtension`` named ``th_glm``
    that links against ``rocblas`` and ``Glm``.

    Returns:
        list: a one-element list containing the ``th_glm`` extension.
    """
    extensions = []
    include_dirs = []
    define_macros = []
    # ``-l`` options are linker flags. Passing them as compile args has no
    # effect (the compile step runs with ``-c``), so the libraries were never
    # linked. ``libraries`` hands them to the link step instead.
    libraries = ['rocblas', 'Glm']
    # NOTE: if libGlm is not on the default linker search path, add its
    # directory via ``library_dirs=[...]`` (or ``extra_link_args=['-L...']``).
    extra_compile_args = {'cxx': []}
    # sorted(): glob order is filesystem-dependent; a fixed source order keeps
    # builds reproducible.
    op_files = sorted(glob.glob('./glmop/*.cu') + glob.glob('./glmop/*.cc'))
    ext_ops = CUDAExtension(
        name="th_glm",
        sources=op_files,
        include_dirs=include_dirs,
        libraries=libraries,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
    extensions.append(ext_ops)
    return extensions

# Package metadata and build configuration. ``BuildExtension`` drives the
# compilation of the extensions returned by ``get_extensions()``.
# NOTE(review): description/keywords mention computer vision; confirm they
# match what this package actually provides.
setup(
    name='th_glm',
    version=get_version(),
    description='Torch th_glm Computer Vision Foundation',
    keywords='computer vision',
    packages=find_packages(),
    include_package_data=False,
    package_data={
        'th_glm': [
            "glmop/*.cu",
            # The trailing comma matters: without it Python concatenates the
            # adjacent literals into "glmop/*.ccglmop/*.h", matching nothing.
            "glmop/*.cc",
            "glmop/*.h",
        ]
    },
    ext_modules=get_extensions(),
    cmdclass={
        'build_ext': BuildExtension
    },
    zip_safe=False
)
