# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""FW agnostic user-end APIs"""

import ctypes
import os
import platform
import subprocess
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12


def get_te_path():
    """Find Transformer Engine install path using pip

    Runs ``pip show transformer_engine`` and returns the value of the
    ``Location`` field from its output, with surrounding whitespace removed.

    Raises:
        subprocess.CalledProcessError: the ``pip show`` command exited with a
            non-zero status.
        ValueError: the command output contains no ``Location`` field.
    """

    command = ["pip", "show", "transformer_engine"]
    result = subprocess.run(command, capture_output=True, check=True, text=True)
    for line in result.stdout.splitlines():
        # Split on the first colon only: the value itself may contain colons
        # (for example a Windows drive letter), which must stay intact.
        key, sep, value = line.partition(":")
        if sep and key == "Location":
            return value.strip()
    raise ValueError("'Location' not found in `pip show transformer_engine` output")
Przemek Tredak's avatar
Przemek Tredak committed
19
20
21


def _load_library():
    """Open the Transformer Engine C extension shared library via ctypes"""

    # Shared-library file suffix for each supported operating system
    suffixes = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}

    os_name = platform.system()
    if os_name not in suffixes:
        raise RuntimeError(f"Unsupported operating system ({os_name})")

    # The library file is looked up directly inside the directory
    # returned by get_te_path().
    lib_file = "libtransformer_engine." + suffixes[os_name]
    full_path = os.path.join(get_te_path(), lib_file)
    return ctypes.CDLL(full_path, mode=ctypes.RTLD_GLOBAL)


# Load the shared library at import time and keep the ctypes handle in a
# module-level variable.
_TE_LIB_CTYPES = _load_library()