# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FW agnostic user-end APIs"""
import ctypes
import os
import platform
import subprocess
import sys


def get_te_path():
    """Find the Transformer Engine install path using ``pip show``.

    Runs ``pip show transformer_engine`` with the current interpreter and
    returns the value of the ``Location`` field from its output.

    Returns:
        str: The ``Location`` reported by pip, with surrounding whitespace removed.

    Raises:
        subprocess.CalledProcessError: If ``pip show`` exits with a non-zero status.
        ValueError: If the ``pip show`` output contains no ``Location`` field.
    """
    command = [sys.executable, "-m", "pip", "show", "transformer_engine"]
    result = subprocess.run(command, capture_output=True, check=True, text=True)

    # Parse "Key: value" lines, splitting each line only at its FIRST ":".
    # Splitting the whole output on every ":" would truncate values that contain
    # a colon, e.g. a Windows path such as "C:\...\site-packages" would become "C".
    for line in result.stdout.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "Location":
            return value.strip()

    # Same exception type the previous list.index()-based lookup raised.
    raise ValueError("Could not find 'Location' in `pip show transformer_engine` output")


def _load_library():
    """Load the core Transformer Engine shared library through ctypes.

    The file suffix is chosen from ``platform.system()``. The library is looked
    up in the directory returned by ``get_te_path()``.

    Returns:
        ctypes.CDLL: Handle to ``libtransformer_engine.<ext>``, opened with
        ``mode=ctypes.RTLD_GLOBAL``.

    Raises:
        RuntimeError: If the operating system is not Linux, Darwin or Windows.
        Errors from ``get_te_path()`` and ``ctypes.CDLL`` are propagated unchanged.
    """
    suffix_by_os = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}
    os_name = platform.system()
    if os_name not in suffix_by_os:
        raise RuntimeError(f"Unsupported operating system ({os_name})")

    library_file = os.path.join(get_te_path(), f"libtransformer_engine.{suffix_by_os[os_name]}")
    return ctypes.CDLL(library_file, mode=ctypes.RTLD_GLOBAL)


def _load_userbuffers():
    """Load the optional userbuffers shared library through ctypes.

    The file suffix is chosen from ``platform.system()``. The library is looked
    up in the directory returned by ``get_te_path()``.

    Returns:
        ctypes.CDLL | None: Handle to ``libtransformer_engine_userbuffers.<ext>``
        (opened with ``mode=ctypes.RTLD_GLOBAL``) if that file exists, otherwise
        ``None``.

    Raises:
        RuntimeError: If the operating system is not Linux, Darwin or Windows.
    """
    os_name = platform.system()
    suffix = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}.get(os_name)
    if suffix is None:
        raise RuntimeError(f"Unsupported operating system ({os_name})")

    library_file = os.path.join(get_te_path(), f"libtransformer_engine_userbuffers.{suffix}")
    # This library is optional: report its absence with None instead of failing.
    if not os.path.exists(library_file):
        return None
    return ctypes.CDLL(library_file, mode=ctypes.RTLD_GLOBAL)


# Load the shared libraries at import time and keep the ctypes handles at module scope.
# NOTE(review): the core library is loaded before userbuffers and both are opened with
# RTLD_GLOBAL, so this order presumably matters -- confirm before reordering.
_TE_LIB_CTYPES = _load_library()
# None when libtransformer_engine_userbuffers is absent from the install directory.
_UB_LIB_CTYPES = _load_userbuffers()