"...git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "0166515cff70ed673b131438a7fc92fcfd08a19e"
__init__.py 4.4 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""Framework-agnostic user-facing APIs."""
6

7
import sys
8
9
10
import glob
import sysconfig
import subprocess
11
12
13
import ctypes
import os
import platform
14
15
16
from pathlib import Path

import transformer_engine
Przemek Tredak's avatar
Przemek Tredak committed
17
18


19
20
21
22
23
24
25
26
27
28
def is_package_installed(package):
    """Report whether ``package`` is installed, as judged by ``pip show``.

    Runs ``<current interpreter> -m pip show <package>`` with its output
    captured and discarded; a non-zero exit status is not raised as an error.

    Args:
        package: Distribution name to look up.

    Returns:
        bool: True if ``pip show`` exited with status 0, otherwise False.
    """
    pip_show_cmd = [sys.executable, "-m", "pip", "show", package]
    completed = subprocess.run(pip_show_cmd, capture_output=True, check=False)
    return completed.returncode == 0


Przemek Tredak's avatar
Przemek Tredak committed
29
def get_te_path():
    """Return the directory that contains the ``transformer_engine`` package.

    The location is derived from the imported package itself (the first entry
    of ``transformer_engine.__path__``), not from a pip query; its parent
    directory is returned as a ``pathlib.Path``.
    """
    package_dir = Path(transformer_engine.__path__[0])
    return package_dir.parent
Przemek Tredak's avatar
Przemek Tredak committed
32
33


34
def _get_sys_extension():
    """Return the shared-library file extension for the running OS.

    Returns:
        str: ``"so"`` on Linux, ``"dylib"`` on macOS (``"Darwin"``) and
        ``"dll"`` on Windows.

    Raises:
        RuntimeError: If ``platform.system()`` reports any other OS.
    """
    os_name = platform.system()
    extension_by_os = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}
    if os_name not in extension_by_os:
        raise RuntimeError(f"Unsupported operating system ({os_name})")
    return extension_by_os[os_name]


48
49
def _load_cudnn():
    """Load the cuDNN shared library into the process with ``RTLD_GLOBAL``.

    Search order (first hit wins):

    1. ``nvidia/cudnn/lib`` under this interpreter's ``purelib`` site-packages
       directory (pip wheels such as ``nvidia-cudnn-cuXX``). Exactly one
       versioned ``libcudnn.<ext>.<N>`` file must be present there.
    2. ``$CUDNN_HOME`` (or ``$CUDNN_PATH``) if set, searched recursively.
    3. ``$CUDA_HOME`` (or ``$CUDA_PATH``, defaulting to ``/usr/local/cuda``),
       searched recursively.
    4. The bare library name, leaving resolution to the dynamic loader's
       default search path (e.g. ``LD_LIBRARY_PATH``).

    In steps 2 and 3 the match with the greatest file name (reverse
    lexicographic order) is loaded.

    Returns:
        ctypes.CDLL: Handle to the loaded library.

    Raises:
        RuntimeError: If more than one versioned library is found in step 1.
        OSError: If step 4 is reached and the loader cannot find the library.
    """
    ext = _get_sys_extension()

    # 1. pip-installed cuDNN wheel. glob.escape() keeps characters such as "["
    # in the site-packages path from being interpreted as glob syntax.
    wheel_libs = glob.glob(
        os.path.join(
            glob.escape(sysconfig.get_path("purelib")),
            f"nvidia/cudnn/lib/libcudnn.{ext}.*[0-9]",
        )
    )
    if wheel_libs:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would silently load an arbitrary match.
        if len(wheel_libs) != 1:
            raise RuntimeError(f"Found {len(wheel_libs)} libcudnn.{ext}.x in nvidia-cudnn-cuXX.")
        return ctypes.CDLL(wheel_libs[0], mode=ctypes.RTLD_GLOBAL)

    # 2. CUDNN_HOME / CUDNN_PATH (only if set), then 3. CUDA_HOME / CUDA_PATH.
    cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    for root in (cudnn_home, cuda_home):
        if not root:
            continue
        libs = glob.glob(f"{glob.escape(root)}/**/libcudnn.{ext}*", recursive=True)
        if libs:
            # Greatest basename wins; equivalent to a reverse sort and taking [0].
            return ctypes.CDLL(max(libs, key=os.path.basename), mode=ctypes.RTLD_GLOBAL)

    # 4. Assume the loader can find it (e.g. via LD_LIBRARY_PATH); raises otherwise.
    return ctypes.CDLL(f"libcudnn.{ext}", mode=ctypes.RTLD_GLOBAL)


82
83
84
85
def _load_library():
    """Load the Transformer Engine core shared library with ``RTLD_GLOBAL``.

    ``libtransformer_engine.<ext>`` is looked up relative to ``get_te_path()``
    in this order, and the first existing file is loaded:

    1. ``<te_path>/transformer_engine/``
    2. ``<te_path>/transformer_engine/wheel_lib/``
    3. ``<te_path>/``

    Returns:
        ctypes.CDLL: Handle to the loaded library.

    Raises:
        FileNotFoundError: If the library exists in none of the locations.
    """
    lib_name = f"libtransformer_engine.{_get_sys_extension()}"
    te_path = get_te_path()
    candidates = (
        te_path / "transformer_engine" / lib_name,
        te_path / "transformer_engine" / "wheel_lib" / lib_name,
        te_path / lib_name,
    )
    so_path = next((path for path in candidates if path.exists()), None)
    # Explicit check rather than `assert` so the clear message survives `python -O`.
    if so_path is None:
        raise FileNotFoundError(f"Could not find {lib_name}")
    # str(): ctypes.CDLL only documents path-like `name` support from Python 3.12.
    return ctypes.CDLL(str(so_path), mode=ctypes.RTLD_GLOBAL)
Przemek Tredak's avatar
Przemek Tredak committed
98
99


100
101
def _load_nvrtc():
    """Load the NVRTC shared library with ``RTLD_GLOBAL``.

    Search order (first hit wins):

    1. ``$CUDA_HOME`` (or ``$CUDA_PATH``, defaulting to ``/usr/local/cuda``),
       searched recursively. Stub libraries and ``libnvrtc-builtins`` are
       ignored, and the match with the greatest file name (reverse
       lexicographic order) is loaded.
    2. The dynamic linker cache as listed by ``ldconfig -p``; the first
       non-stub, non-builtins ``libnvrtc`` entry is loaded.
    3. The bare library name, leaving resolution to the loader's default
       search path (e.g. ``LD_LIBRARY_PATH``).

    Returns:
        ctypes.CDLL: Handle to the loaded library.

    Raises:
        OSError: If step 3 is reached and the loader cannot find the library.
    """
    ext = _get_sys_extension()

    def _is_excluded(name):
        # Stub libraries and the builtins helper are not the NVRTC library itself.
        return "stub" in name or "libnvrtc-builtins" in name

    # 1. CUDA toolkit directory.
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = [
        lib
        for lib in glob.glob(f"{glob.escape(cuda_home)}/**/libnvrtc.{ext}*", recursive=True)
        if not _is_excluded(lib)
    ]
    if libs:
        # Greatest basename wins; equivalent to a reverse sort and taking [0].
        return ctypes.CDLL(max(libs, key=os.path.basename), mode=ctypes.RTLD_GLOBAL)

    # 2. Linker cache. ldconfig is run directly and filtered in Python: a shell
    # `ldconfig -p | grep libnvrtc` pipeline exits non-zero when nothing matches
    # (or ldconfig is missing), and check_output() would raise on that, making
    # step 3 unreachable.
    try:
        ldconfig = subprocess.run(["ldconfig", "-p"], capture_output=True, check=False)
        ldconfig_output = ldconfig.stdout.decode("utf-8", errors="replace")
    except OSError:
        ldconfig_output = ""
    for line in ldconfig_output.splitlines():
        if _is_excluded(line) or "libnvrtc" not in line or "=>" not in line:
            continue
        # Entries look like "<soname> (<flags>) => <path>".
        return ctypes.CDLL(line.partition("=>")[2].strip(), mode=ctypes.RTLD_GLOBAL)

    # 3. Assume the loader can find it (e.g. via LD_LIBRARY_PATH); raises otherwise.
    return ctypes.CDLL(f"libnvrtc.{ext}", mode=ctypes.RTLD_GLOBAL)


# Eagerly load the native libraries at import time unless NVTE_PROJECT_BUILDING
# is set (presumably during a source build) and NVTE_RELEASE_BUILD is unset or
# parses to 0. The order matters: cuDNN and NVRTC are loaded with RTLD_GLOBAL
# before libtransformer_engine, presumably so their symbols are visible to it
# (NOTE(review): confirm).
if (
    "NVTE_PROJECT_BUILDING" not in os.environ
    or int(os.getenv("NVTE_RELEASE_BUILD", "0")) != 0
):
    _CUDNN_LIB_CTYPES = _load_cudnn()
    _NVRTC_LIB_CTYPES = _load_nvrtc()
    _TE_LIB_CTYPES = _load_library()