import os
from pathlib import Path

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


# Directory that contains this setup script, resolved to an absolute path so
# that source and include locations do not depend on the current working
# directory.
ROOT_DIR = Path(__file__).parent.resolve()


def get_extensions():
    """Describe the native extension modules shipped with this package.

    Returns:
        A one-element list holding the ``CUDAExtension`` named
        ``lmcustomop.op``, built from the sources under ``csrc/`` with
        ``csrc/`` on the include path.
    """
    csrc_dir = ROOT_DIR / "csrc"

    # Flags applied to both the host and the device compiler: full
    # optimization, all warnings suppressed.
    shared_flags = ["-O3", "-w"]

    # Extra flags passed only under the "nvcc" key.
    # NOTE(review): these look HIP/ROCm-specific (an LLVM backend option and
    # a HIP feature macro) — confirm the intended target toolchain.
    device_flags = [
        "-mllvm",
        "-enable-num-vgprs-512=true",
        "-DHIP_ENABLE_WARP_SYNC_BUILTINS",
    ]

    source_files = [
        str(csrc_dir / file_name)
        for file_name in ("export.cpp", "fuse_rms_roped.cu")
    ]

    return [
        CUDAExtension(
            name="lmcustomop.op",
            sources=source_files,
            include_dirs=[str(csrc_dir)],
            extra_compile_args={
                "cxx": list(shared_flags),
                "nvcc": shared_flags + device_flags,
            },
        )
    ]


# Register the package and its compiled extension with setuptools.
setup(
    # --- Distribution metadata ---
    name="lmcustomop",
    # The version can be overridden at build time through the environment.
    version=os.environ.get("LMCUSTOMOP_VERSION", "0.0.1"),
    description="Minimal lmcustomop package",
    install_requires=["torch"],
    # --- Python package layout ---
    # The directory holding this script is itself the "lmcustomop" package.
    package_dir={"lmcustomop": "."},
    packages=["lmcustomop"],
    zip_safe=False,
    # --- Native extension build ---
    cmdclass={"build_ext": BuildExtension},
    ext_modules=get_extensions(),
)
