# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Framework-agnostic user-end APIs"""

import ctypes
import os
import platform
from pathlib import Path

import transformer_engine


def get_te_path() -> Path:
    """Return the directory that contains the `transformer_engine` package.

    The location is taken from the first entry of the imported package's
    ``__path__``; no pip query is performed. The parent of that package
    directory is returned.

    Returns:
        Path: parent directory of the ``transformer_engine`` package directory.
    """
    package_dir = Path(transformer_engine.__path__[0])
    return package_dir.parent


def _get_sys_extension():
Przemek Tredak's avatar
Przemek Tredak committed
21
22
23
24
25
26
27
28
    system = platform.system()
    if system == "Linux":
        extension = "so"
    elif system == "Darwin":
        extension = "dylib"
    elif system == "Windows":
        extension = "dll"
    else:
29
        raise RuntimeError(f"Unsupported operating system ({system})")
Przemek Tredak's avatar
Przemek Tredak committed
30

31
32
33
34
35
36
37
38
39
40
41
42
    return extension


def _load_library():
    """Load shared library with Transformer Engine C extensions.

    The library file is looked up first inside the ``transformer_engine``
    directory under `get_te_path()`, and then directly in `get_te_path()`.

    Returns:
        ctypes.CDLL: handle to the library, opened with ``ctypes.RTLD_GLOBAL``.

    Raises:
        AssertionError: if the library file exists in neither location.
    """
    # Resolve these once instead of re-evaluating them for every candidate.
    lib_name = f"libtransformer_engine.{_get_sys_extension()}"
    te_path = get_te_path()

    so_path = te_path / "transformer_engine" / lib_name
    if not so_path.exists():
        so_path = te_path / lib_name
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert so_path.exists(), f"Could not find {lib_name}"

    # Pass a plain string: ctypes documents path-like `name` arguments only
    # from Python 3.12 onwards.
    return ctypes.CDLL(str(so_path), mode=ctypes.RTLD_GLOBAL)


# Load the core shared library at import time unless NVTE_PROJECT_BUILDING is
# present in the environment (presumably set while the project itself is being
# built, before the library exists -- confirm against the build scripts).
if os.environ.get("NVTE_PROJECT_BUILDING") is None:
    _TE_LIB_CTYPES = _load_library()