__init__.py 12.1 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""FW agnostic user-end APIs"""
6

7
import ctypes
import functools
import glob
import importlib
import importlib.util
from importlib.metadata import version, metadata, PackageNotFoundError
import logging
import os
from pathlib import Path
import platform
import subprocess
import sys
import sysconfig
from typing import Optional
20

Przemek Tredak's avatar
Przemek Tredak committed
21

22
23
24
25
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _is_pip_package_installed(package) -> bool:
    """Report whether distribution metadata exists for `package`.

    An import check is not sufficient here: the module may be importable
    from the current directory (because the shared library module is
    present) without having been installed via pip. Only installed
    metadata counts.
    """
    try:
        metadata(package)
    except PackageNotFoundError:
        installed = False
    else:
        installed = True
    return installed


@functools.lru_cache(maxsize=None)
def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]:
    """
    Locate the shared object whose file name starts with `prefix` under `te_path`.

    The following directories are scanned (non-recursively):
        1. `te_path` itself (editable install).
        2. `te_path/transformer_engine` (source install).
        3. `te_path/transformer_engine/wheel_lib` (PyPI install).

    Returns the single matching path, or None if `te_path` is not a directory,
    does not contain `transformer_engine`, or no file matches.
    Raises RuntimeError if more than one file matches.
    """
    module_dir = te_path / "transformer_engine"

    # Nothing to search unless the top level dir exists and holds the module.
    if not (te_path.is_dir() and module_dir.exists()):
        return None

    candidate_dirs = (
        te_path,  # Editable build.
        module_dir,  # Regular source build.
        module_dir / "wheel_lib",  # PyPI.
    )

    # Collect every file with the right prefix and the platform's library suffix.
    matches = [
        entry
        for folder in candidate_dirs
        if folder.is_dir()
        for entry in folder.iterdir()
        if entry.name.startswith(prefix) and entry.suffix == _get_sys_extension()
    ]

    if not matches:
        return None
    if len(matches) > 1:
        raise RuntimeError(f"Multiple files found: {matches}")
    return matches[0]


@functools.lru_cache(maxsize=None)
def _get_shared_object_file(library: str) -> Path:
    """
    Path to shared object file for a Transformer Engine library.

    TE libraries are 'core', 'torch', or 'jax'. This function first
    searches in the imported TE directory, and then in the
    site-packages directory.

    `library`: One of "core", "torch" or "jax".

    Returns the path of the matching shared object.
    Raises AssertionError for an unsupported `library` and
    FileNotFoundError if no shared object is found in either location.
    """

    # Check provided input and determine the correct prefix for .so.
    assert library in ("core", "torch", "jax"), f"Unsupported TE library {library}."
    if library == "core":
        so_prefix = "libtransformer_engine"
    else:
        so_prefix = f"transformer_engine_{library}"

    # Search for shared lib in imported directory.
    # `find_spec` returns None when the package cannot be located, and a spec's
    # `origin` may be None; skip this location then instead of failing with an
    # AttributeError, so the site-packages search below still runs.
    spec = importlib.util.find_spec("transformer_engine")
    if spec is not None and spec.origin is not None:
        te_path = Path(spec.origin).parent.parent
        so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
        if so_path is not None:
            return so_path

    # Search for shared lib in site-packages directory
    te_path = Path(sysconfig.get_paths()["purelib"])
    so_path = _find_shared_object_in_te_dir(te_path, so_prefix)
    if so_path is not None:
        return so_path

    raise FileNotFoundError(
        f"Could not find shared object file for Transformer Engine {library} lib."
    )
Przemek Tredak's avatar
Przemek Tredak committed
113
114


115
@functools.lru_cache(maxsize=None)
def load_framework_extension(framework: str) -> None:
    """
    Load shared library with Transformer Engine framework bindings
    and verify package consistency if installed via PyPI.

    `framework`: Either "jax" or "torch".

    On success the extension is registered in `sys.modules` under the name
    `transformer_engine_<framework>`.

    Raises AssertionError for an unsupported framework, or when the framework
    extension is installed via pip but the `transformer-engine` /
    `transformer-engine-cu12` packages are missing or their versions differ.
    Errors raised while locating or executing the shared object propagate.
    """

    # Supported frameworks.
    assert framework in ("jax", "torch"), f"Unsupported framework {framework}"

    # Name of the framework extension library.
    module_name = f"transformer_engine_{framework}"

    # Name of the pip extra dependency for framework extensions from PyPI.
    extra_dep_name = module_name
    if framework == "torch":
        extra_dep_name = "pytorch"

    # If the framework extension pip package is installed, it means that TE is installed via
    # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
    # extension are all installed via PyPI and have matching version.
    if _is_pip_package_installed(module_name):
        assert _is_pip_package_installed(
            "transformer_engine"
        ), "Could not find `transformer-engine`."
        assert _is_pip_package_installed(
            "transformer_engine_cu12"
        ), "Could not find `transformer-engine-cu12`."
        assert (
            version(module_name)
            == version("transformer-engine")
            == version("transformer-engine-cu12")
        ), (
            "TransformerEngine package version mismatch. Found"
            f" {module_name} v{version(module_name)}, transformer-engine"
            f" v{version('transformer-engine')}, and transformer-engine-cu12"
            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
            f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
        )

    # If the core package is installed via PyPI, log if
    # the framework extension is not found from PyPI.
    # Note: Should we error? This is a rare use case.
    if _is_pip_package_installed("transformer-engine-cu12"):
        if not _is_pip_package_installed(module_name):
            # Pass both values as lazy logging arguments rather than mixing an
            # f-string into the %-format template; the rendered text is unchanged.
            _logger.info(
                "Could not find package %s. Install transformer-engine using "
                "'pip3 install transformer-engine[%s]==VERSION'",
                module_name,
                extra_dep_name,
            )

    # After all checks are completed, load the shared object file.
    spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework))
    solib = importlib.util.module_from_spec(spec)
    # The module must be registered before execution, but it must not stay
    # registered if execution fails: a later `import` of the same name would
    # otherwise silently receive the half-initialized module.
    sys.modules[module_name] = solib
    try:
        spec.loader.exec_module(solib)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise


@functools.lru_cache(maxsize=None)
def _get_sys_extension() -> str:
    """Return the shared-object file suffix for the current operating system.

    Raises RuntimeError if `platform.system()` is not Linux, Darwin or Windows.
    """
    suffix_by_os = {
        "Linux": ".so",
        "Darwin": ".dylib",
        "Windows": ".dll",
    }
    os_name = platform.system()
    suffix = suffix_by_os.get(os_name)
    if suffix is None:
        raise RuntimeError(f"Unsupported operating system ({os_name})")
    return suffix
185
186


187
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
    """
    Attempts to load shared object files installed via pip.

    `lib_name`: Name of package as found in the `nvidia` dir in python environment.

    Returns a `(found, handles)` tuple: `found` tells whether any versioned
    library file matched under `<purelib>/nvidia/<lib_name>/lib`, and `handles`
    is the list of `ctypes.CDLL` objects, one per matched file.
    """
    pattern = os.path.join(
        sysconfig.get_path("purelib"),
        f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
    )

    # Every match is loaded with RTLD_GLOBAL.
    handles = [
        ctypes.CDLL(candidate, mode=ctypes.RTLD_GLOBAL) for candidate in glob.glob(pattern)
    ]
    return bool(handles), handles


@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir() -> str:
    """Return the `cuda_runtime` directory of the `nvidia` python package.

    Returns an empty string if `nvidia` cannot be imported, has no file
    location, or does not contain a `cuda_runtime` directory.
    """
    try:
        import nvidia
    except ModuleNotFoundError:
        return ""

    # Installing some nvidia-* packages, like nvshmem, creates the `nvidia` name, so the
    # import above doesn't raise. However, they don't set the "__file__" attribute.
    package_file = nvidia.__file__
    if package_file is None:
        return ""

    candidate = Path(package_file).parent / "cuda_runtime"
    if candidate.exists():
        return str(candidate)
    return ""


230
@functools.lru_cache(maxsize=None)
def _load_cudnn():
    """Load CUDNN shared library.

    Locations are tried in this order; the first hit is loaded with RTLD_GLOBAL:
        1. `CUDNN_HOME` / `CUDNN_PATH` (recursive search), if either is set.
        2. `CUDA_HOME` / `CUDA_PATH` / `/usr/local/cuda` (recursive search).
        3. The pip-installed `nvidia/cudnn` package.
        4. The `ldconfig -p` cache.
        5. The dynamic loader's default search path.

    Returns a `ctypes.CDLL` handle, or a list of handles when the library
    comes from the pip package. Raises OSError if the final attempt fails.
    """
    lib_file = f"libcudnn{_get_sys_extension()}"

    # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
    cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
    if cudnn_home:
        libs = glob.glob(f"{cudnn_home}/**/{lib_file}*", recursive=True)
        libs.sort(reverse=True, key=os.path.basename)
        if libs:
            return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

    # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = glob.glob(f"{cuda_home}/**/{lib_file}*", recursive=True)
    libs.sort(reverse=True, key=os.path.basename)
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

    # Attempt to locate cuDNN in Python dist-packages
    found, handle = _load_nvidia_cuda_library("cudnn")
    if found:
        return handle

    # Attempt to locate libcudnn via ldconfig. A missing or failing `ldconfig`
    # must not abort the search: the last-resort attempt below still has to run.
    try:
        ldconfig_out = subprocess.check_output(["ldconfig", "-p"])
    except (OSError, subprocess.CalledProcessError):
        ldconfig_out = b""
    # Match the main library file name, as the glob patterns above do, so that a
    # companion library whose name merely starts with "libcudnn" is not picked.
    sos = [
        line.partition("=>")[2].strip()
        for line in ldconfig_out.decode("utf-8").split("\n")
        if lib_file in line and "=>" in line
    ]
    if sos:
        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(lib_file, mode=ctypes.RTLD_GLOBAL)
266
267


268
@functools.lru_cache(maxsize=None)
def _load_nvrtc():
    """Load NVRTC shared library.

    Locations are tried in this order; the first hit is loaded with RTLD_GLOBAL:
        1. `CUDA_HOME` / `CUDA_PATH` / `/usr/local/cuda` (recursive search,
           skipping stub and `libnvrtc-builtins` files).
        2. The pip-installed `nvidia/cuda_nvrtc` package.
        3. The `ldconfig -p` cache.
        4. The dynamic loader's default search path.

    Returns a `ctypes.CDLL` handle, or a list of handles when the library
    comes from the pip package. Raises OSError if the final attempt fails.
    """
    lib_file = f"libnvrtc{_get_sys_extension()}"

    # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = glob.glob(f"{cuda_home}/**/{lib_file}*", recursive=True)
    libs = [path for path in libs if "stub" not in path and "libnvrtc-builtins" not in path]
    libs.sort(reverse=True, key=os.path.basename)
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

    # Attempt to locate NVRTC in Python dist-packages
    found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
    if found:
        return handle

    # Attempt to locate NVRTC via ldconfig. A missing or failing `ldconfig`
    # must not abort the search: the last-resort attempt below still has to run.
    try:
        ldconfig_out = subprocess.check_output(["ldconfig", "-p"])
    except (OSError, subprocess.CalledProcessError):
        ldconfig_out = b""
    # Match the main library file name so that `libnvrtc-builtins`, which the
    # search above excludes explicitly, is not picked here either.
    sos = [
        line.partition("=>")[2].strip()
        for line in ldconfig_out.decode("utf-8").split("\n")
        if lib_file in line and "=>" in line
    ]
    if sos:
        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(lib_file, mode=ctypes.RTLD_GLOBAL)
296
297


vasunvidia's avatar
vasunvidia committed
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
@functools.lru_cache(maxsize=None)
def _load_curand():
    """Load cuRAND shared library.

    Locations are tried in this order; the first hit is loaded with RTLD_GLOBAL:
        1. `CUDA_HOME` / `CUDA_PATH` / `/usr/local/cuda` (recursive search,
           skipping stub files).
        2. The pip-installed `nvidia/curand` package.
        3. The `ldconfig -p` cache.
        4. The dynamic loader's default search path.

    Returns a `ctypes.CDLL` handle, or a list of handles when the library
    comes from the pip package. Raises OSError if the final attempt fails.
    """
    lib_file = f"libcurand{_get_sys_extension()}"

    # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
    libs = glob.glob(f"{cuda_home}/**/{lib_file}*", recursive=True)
    libs = [path for path in libs if "stub" not in path]
    libs.sort(reverse=True, key=os.path.basename)
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)

    # Attempt to locate cuRAND in Python dist-packages
    found, handle = _load_nvidia_cuda_library("curand")
    if found:
        return handle

    # Attempt to locate cuRAND via ldconfig. A missing or failing `ldconfig`
    # must not abort the search: the last-resort attempt below still has to run.
    try:
        ldconfig_out = subprocess.check_output(["ldconfig", "-p"])
    except (OSError, subprocess.CalledProcessError):
        ldconfig_out = b""
    # Match the main library file name, as the glob pattern above does.
    sos = [
        line.partition("=>")[2].strip()
        for line in ldconfig_out.decode("utf-8").split("\n")
        if lib_file in line and "=>" in line
    ]
    if sos:
        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)

    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
    return ctypes.CDLL(lib_file, mode=ctypes.RTLD_GLOBAL)


328
329
330
331
332
333
@functools.lru_cache(maxsize=None)
def _load_core_library():
    """Load shared library with Transformer Engine C extensions.

    Returns the `ctypes.CDLL` handle, loaded with RTLD_GLOBAL. Errors raised
    while locating or loading the shared object propagate to the caller.
    """
    # `ctypes.CDLL` is only documented to accept path-like objects from
    # Python 3.12 on, so hand it a plain string.
    so_path = os.fspath(_get_shared_object_file("core"))
    return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)


334
335
336
# Load the native libraries at import time. This is skipped only when
# NVTE_PROJECT_BUILDING is set and NVTE_RELEASE_BUILD is unset or 0.
if (
    os.environ.get("NVTE_PROJECT_BUILDING") is None
    or int(os.environ.get("NVTE_RELEASE_BUILD", "0")) != 0
):
    _CUDNN_LIB_CTYPES = _load_cudnn()
    _NVRTC_LIB_CTYPES = _load_nvrtc()
    _CURAND_LIB_CTYPES = _load_curand()
    # These two hold the `(found, handles)` tuple returned by the pip-package loader.
    _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
    _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
    _TE_LIB_CTYPES = _load_core_library()

    # Needed to find the correct headers for NVRTC kernels.
    # A non-empty value already present in the environment is left alone.
    if not os.environ.get("NVTE_CUDA_INCLUDE_DIR"):
        if _nvidia_cudart_include_dir():
            os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()