setup.py 3.67 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# setup.py

import subprocess

from typing import Literal

from setuptools import setup


def get_default_dependencies(platform=None):
    """Return the ``install_requires`` list for a hardware platform.

    Args:
        platform: Optional platform name, one of ``"cuda"``, ``"cpu"``,
            ``"rocm"``, ``"xpu"`` or ``"npu"``. When omitted (``None``) the
            platform is detected with ``get_platform()``.

    Returns:
        A new list of requirement specifiers for that platform.

    Raises:
        ValueError: If ``platform`` is not one of the supported names.
            Previously an unknown platform fell through every branch and
            silently produced ``None`` (i.e. no requirements at all).
    """
    if platform is None:
        platform = get_platform()

    if platform in ("cuda", "cpu"):
        return [
            "torch>=2.1.2",
            "triton>=2.3.1",
        ]
    if platform == "rocm":
        # Only triton is listed for ROCm; torch is not pinned here.
        return [
            "triton>=3.0.0",
        ]
    if platform == "xpu":
        return [
            "torch>=2.6.0",
        ]
    if platform == "npu":
        # TODO: Currently, triton-ascend is not compatible with torch 2.7.1. We will upgrade it later.
        return ["torch==2.6.0", "torch_npu==2.6.0", "triton-ascend"]

    raise ValueError(
        f"Unsupported platform {platform!r}; expected one of "
        "'cuda', 'cpu', 'rocm', 'xpu', 'npu'."
    )


def get_optional_dependencies():
    """Return the ``extras_require`` mapping of optional dependency groups.

    Only a single group, ``"dev"``, is defined; it is installed with
    ``pip install liger_kernel[dev]``.
    """
    dev_requirements = [
        "transformers>=4.52.0",
        "matplotlib>=3.7.2",
        "ruff>=0.12.0",
        "pytest>=7.1.2",
        "pytest-xdist",
        "pytest-cov",
        "pytest-asyncio",
        "pytest-rerunfailures",
        "datasets>=2.19.2",
        "seaborn",
        "mkdocs-material",
        "torchvision>=0.20",
        "prek>=0.2.28",
    ]
    return dict(dev=dev_requirements)


def is_xpu_available():
    """
    Best-effort check for an Intel XPU (GPU).

    Two probes are tried in order:

    1. ``xpu-smi`` exiting with status 0. This tool is often missing right
       now, so its absence alone is not conclusive.
    2. ``sycl-ls`` listing a device whose description contains
       ``level_zero:gpu``.

    Returns:
        True if either probe indicates an Intel GPU, otherwise False. A probe
        that cannot be run (missing or non-executable binary) or exits with a
        non-zero status counts as "not detected" instead of raising.
    """
    try:
        subprocess.run(["xpu-smi"], check=True)
        return True
    except (subprocess.SubprocessError, OSError):
        # OSError covers FileNotFoundError as well as PermissionError etc.
        pass

    try:
        # Run sycl-ls directly rather than through a shell: no shell features
        # are needed, and a missing binary surfaces as FileNotFoundError.
        result = subprocess.run(["sycl-ls"], check=True, capture_output=True)
    except (subprocess.SubprocessError, OSError):
        return False

    # Decode leniently so unexpected bytes in the device listing cannot abort
    # installation with a UnicodeDecodeError.
    return "level_zero:gpu" in result.stdout.decode(errors="replace")


def is_ascend_available() -> bool:
    """Best-effort Ascend NPU detection.

    Runs ``npu-smi info`` and treats a zero exit status as a sign that an
    Ascend NPU is present. No environment variables are consulted.

    Returns:
        True if ``npu-smi info`` succeeds; False if the tool is missing,
        cannot be executed, or exits with a non-zero status.
    """
    try:
        subprocess.run(["npu-smi", "info"], check=True)
    except (subprocess.SubprocessError, OSError):
        # OSError covers FileNotFoundError as well as PermissionError etc.;
        # detection is best-effort and must not abort installation.
        return False
    return True


def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu", "npu"]:
    """
    Detect the accelerator platform without depending on torch.

    Probes run in priority order and the first success wins:

    1. ``nvidia-smi`` exits with status 0  -> ``"cuda"``
    2. ``rocm-smi`` exits with status 0    -> ``"rocm"``
    3. ``is_xpu_available()`` is true      -> ``"xpu"``
    4. ``is_ascend_available()`` is true   -> ``"npu"``

    If nothing is detected, ``"cpu"`` is returned. A short message naming
    the result is printed in every case.
    """
    command_probes = (
        (["nvidia-smi"], "cuda", "NVIDIA GPU detected"),
        (["rocm-smi"], "rocm", "ROCm GPU detected"),
    )
    for command, detected, message in command_probes:
        try:
            subprocess.run(command, check=True)
        except (subprocess.SubprocessError, OSError):
            # Missing / non-executable tool or non-zero exit: try the next
            # probe instead of aborting installation.
            continue
        print(message)
        return detected

    if is_xpu_available():
        print("Intel GPU detected")
        return "xpu"
    if is_ascend_available():
        print("Ascend NPU detected")
        return "npu"

    print("No GPU detected")
    return "cpu"


setup(
    name="liger_kernel",
    package_dir={"": "src"},
    # NOTE(review): setuptools installs only the packages listed here;
    # subpackages of liger_kernel are not added automatically unless they are
    # configured elsewhere (e.g. pyproject.toml) — verify.
    packages=["liger_kernel"],
    # Requirements are chosen from the hardware detected on the machine
    # running this script.
    install_requires=get_default_dependencies(),
    extras_require=get_optional_dependencies(),
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "Programming Language :: Python :: 3",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
        # Must be a registered trove classifier (PyPI rejects unknown ones);
        # "BSD License" is the registered OSI classifier covering BSD-2-Clause.
        "License :: OSI Approved :: BSD License",
        "Operating System :: OS Independent",
    ],
)