__init__.py 1.74 KB
Newer Older
1
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
#
# See LICENSE for license information.

"""FW agnostic user-end APIs"""
6
7
8
9
import ctypes
import os
import platform
import subprocess
import sys
Przemek Tredak's avatar
Przemek Tredak committed
10
11
12


def get_te_path():
    """Find Transformer Engine install path using pip.

    Runs ``pip show transformer_engine`` through the pip module of the running
    interpreter and returns the value of the ``Location`` field it reports.

    Returns:
        str: The install location reported by pip for ``transformer_engine``.

    Raises:
        subprocess.CalledProcessError: If ``pip show`` exits with a non-zero
            status (for example, when the package is not installed).
        ValueError: If the ``pip show`` output has no ``Location`` field.
    """
    # Use the pip belonging to *this* interpreter; a bare "pip" on PATH may
    # belong to a different Python environment.
    command = [sys.executable, "-m", "pip", "show", "transformer_engine"]
    result = subprocess.run(command, capture_output=True, check=True, text=True)
    for line in result.stdout.splitlines():
        # Split only on the first colon so values that themselves contain
        # colons (e.g. Windows paths like "C:\\...\\site-packages") stay intact.
        key, sep, value = line.partition(":")
        if sep and key.strip() == "Location":
            return value.strip()
    raise ValueError("'Location' not found in 'pip show transformer_engine' output")
Przemek Tredak's avatar
Przemek Tredak committed
19
20
21


def _load_library():
    """Load shared library with Transformer Engine C extensions.

    The file ``libtransformer_engine.<ext>`` is opened from the directory
    returned by :func:`get_te_path`. ``<ext>`` is chosen from the host OS:
    ``so`` on Linux, ``dylib`` on macOS (Darwin) and ``dll`` on Windows.
    It is opened with ``RTLD_GLOBAL``.

    Returns:
        ctypes.CDLL: Handle to the loaded library.

    Raises:
        RuntimeError: If the host operating system is not one of the above.
    """
    suffix_by_os = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}
    os_name = platform.system()
    if os_name not in suffix_by_os:
        raise RuntimeError(f"Unsupported operating system ({os_name})")

    install_dir = get_te_path()
    library_file = f"libtransformer_engine.{suffix_by_os[os_name]}"
    library_path = os.path.join(install_dir, library_file)
    return ctypes.CDLL(library_path, mode=ctypes.RTLD_GLOBAL)


40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def _load_mpi():
    """Load MPI shared library if it is present.

    The library ``libmpi.<ext>`` is looked up in ``$MPI_HOME/lib``, where
    ``MPI_HOME`` defaults to ``/usr/local/mpi``. ``<ext>`` is ``so`` on Linux,
    ``dylib`` on macOS (Darwin) and ``dll`` on Windows. When found, it is
    opened with ``RTLD_GLOBAL``.

    Returns:
        ctypes.CDLL or None: Handle to the loaded MPI library, or ``None`` if
        the library file does not exist.

    Raises:
        RuntimeError: If the host operating system is not one of the above.
        OSError: If the library file exists but cannot be loaded.
    """
    system = platform.system()
    if system == "Linux":
        extension = "so"
    elif system == "Darwin":
        extension = "dylib"
    elif system == "Windows":
        extension = "dll"
    else:
        raise RuntimeError(f"Unsupported operating system ({system})")
    lib_name = "libmpi." + extension
    mpi_home = os.environ.get("MPI_HOME", "/usr/local/mpi")
    dll_path = os.path.join(mpi_home, "lib", lib_name)

    # Check for the library file itself rather than only MPI_HOME: an existing
    # directory without lib/libmpi.<ext> would otherwise make ctypes.CDLL raise
    # OSError and abort the import, although MPI is optional (None is returned
    # when it is absent).
    if os.path.isfile(dll_path):
        return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
    return None


# Load MPI first (when available) with RTLD_GLOBAL, then the Transformer Engine
# library. The MPI handle is kept under its own name instead of being bound to
# _TE_LIB_CTYPES and immediately overwritten.
# NOTE(review): presumably libtransformer_engine relies on MPI symbols made
# global by the first load; confirm before reordering these two calls.
_MPI_LIB_CTYPES = _load_mpi()
_TE_LIB_CTYPES = _load_library()