liboperators.py 3.03 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
3
import os
import platform
import ctypes
PanZezhong's avatar
PanZezhong committed
4
from ctypes import c_int, c_int64, c_uint64, Structure, POINTER
PanZezhongQY's avatar
PanZezhongQY committed
5
6
from .datatypes import *
from .devices import *
7
from pathlib import Path
PanZezhongQY's avatar
PanZezhongQY committed
8
9
10
11

# Both names are plain aliases of ctypes.c_int.
# NOTE(review): presumably meant to label device and operator-type arguments
# in C signatures -- confirm against callers.
Device = c_int
Optype = c_int

12
# Root directory under which open_lib() looks for the shared libraries.
# The INFINI_ROOT environment variable takes precedence; when it is unset or
# empty the default is the ".infini" directory in the user's home.
INFINI_ROOT = os.getenv("INFINI_ROOT") or str(Path.home() / ".infini")
PanZezhongQY's avatar
PanZezhongQY committed
13
14
15


class TensorDescriptor(Structure):
    """ctypes structure standing in for the library's tensor descriptor.

    No fields are declared, so Python code never reads or writes the
    descriptor's contents; it is only handled through pointers
    (see ``infiniopTensorDescriptor_t``).
    """

    _fields_ = []


# Pointer-to-TensorDescriptor type used in the C function signatures.
infiniopTensorDescriptor_t = ctypes.POINTER(TensorDescriptor)


class CTensor:
    """Pair a library tensor descriptor with the torch tensor it describes.

    Attributes:
        descriptor: The descriptor passed to the constructor; reset to
            ``None`` by :meth:`destroyDesc`.
        torch_tensor_: The tensor object passed to the constructor.
        data: The value returned by ``torch_tensor.data_ptr()`` at
            construction time.
    """

    def __init__(self, desc, torch_tensor):
        self.descriptor = desc
        # NOTE(review): holding the tensor presumably keeps the buffer behind
        # ``data`` alive for as long as this object exists -- confirm.
        self.torch_tensor_ = torch_tensor
        self.data = torch_tensor.data_ptr()

    def destroyDesc(self, lib_):
        """Destroy the descriptor through *lib_* and forget it.

        Args:
            lib_: Object exposing ``infiniopDestroyTensorDescriptor``.
        """
        lib_.infiniopDestroyTensorDescriptor(self.descriptor)
        self.descriptor = None
PanZezhongQY's avatar
PanZezhongQY committed
31
32
33
34
35
36
37
38
39


class Handle(Structure):
    """ctypes structure with two C ``int`` members: ``device`` and ``device_id``."""

    _fields_ = [
        ("device", c_int),
        ("device_id", c_int),
    ]


# Pointer-to-Handle type used in the C function signatures.
infiniopHandle_t = POINTER(Handle)


40
41
42
43
44
45
46
47
48
49
50
51
52
53
class InfiniLib:
    """Expose two loaded libraries through a single object.

    Attributes that are not found on the instance itself are looked up on
    ``libop`` first and then on ``librt``.

    Args:
        librt: The runtime library object.
        libop: The operators library object.
    """

    def __init__(self, librt, libop):
        self.librt = librt
        self.libop = libop

    def __getattr__(self, name):
        """Resolve *name* on ``libop``, then on ``librt``.

        Raises:
            AttributeError: If neither library provides *name*.
        """
        # __getattr__ only runs after normal lookup has failed. Read the
        # wrapped libraries straight from the instance dict: going through
        # ``self.libop`` on an instance whose __init__ has not run (copy,
        # pickle, __new__) would re-enter this method and recurse forever.
        state = self.__dict__
        for key in ("libop", "librt"):
            if key in state and hasattr(state[key], name):
                return getattr(state[key], name)
        raise AttributeError(f"Attribute {name} not found in library")


PanZezhongQY's avatar
PanZezhongQY committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Open operators library
def open_lib():
    """Load the InfiniRT and InfiniOP shared libraries and declare C signatures.

    The library files are searched for under ``INFINI_ROOT``:
    ``bin/infiniop.dll`` and ``bin/infinirt.dll`` on Windows,
    ``lib/libinfiniop.so`` and ``lib/libinfinirt.so`` on Linux.

    Returns:
        InfiniLib: Wrapper around both libraries with ``argtypes`` and
        ``restype`` set for the tensor-descriptor, handle and set-device
        functions.

    Raises:
        OSError: If the current platform is neither Windows nor Linux.
            ``ctypes.CDLL`` also raises ``OSError`` when a library cannot
            be loaded.
        AssertionError: If either library file cannot be found.
    """

    def find_library_in_ld_path(subdir, library_name):
        """Return the path of *library_name* under ``INFINI_ROOT/subdir``, or None."""
        ld_library_path = os.path.join(INFINI_ROOT, subdir)
        # NOTE(review): the joined path is split on os.pathsep, so if
        # INFINI_ROOT lists several roots only the last one gets *subdir*
        # appended. Kept as-is; confirm whether multi-root values are expected.
        for path in ld_library_path.split(os.pathsep):
            full_path = os.path.join(path, library_name)
            if os.path.isfile(full_path):
                return full_path
        return None

    system_name = platform.system()
    # Pick the platform-specific directory and file names.
    if system_name == "Windows":
        libop_path = find_library_in_ld_path("bin", "infiniop.dll")
        librt_path = find_library_in_ld_path("bin", "infinirt.dll")
    elif system_name == "Linux":
        libop_path = find_library_in_ld_path("lib", "libinfiniop.so")
        librt_path = find_library_in_ld_path("lib", "libinfinirt.so")
    else:
        # Without this branch the paths would be unbound and the checks below
        # would fail with an unrelated UnboundLocalError.
        raise OSError(
            f"Unsupported platform for the Infini libraries: {system_name}"
        )

    # Raised explicitly rather than via ``assert``: assert statements are
    # stripped under ``python -O``, which would hand None to ctypes.CDLL.
    # AssertionError is kept so existing callers see the same exception type.
    if libop_path is None:
        raise AssertionError(
            "Cannot find infiniop.dll or libinfiniop.so. Check if INFINI_ROOT is set correctly."
        )
    if librt_path is None:
        raise AssertionError(
            "Cannot find infinirt.dll or libinfinirt.so. Check if INFINI_ROOT is set correctly."
        )

    # The runtime library is loaded before the operators library.
    librt = ctypes.CDLL(librt_path)
    libop = ctypes.CDLL(libop_path)
    lib = InfiniLib(librt, libop)

    # Tensor descriptor creation / destruction.
    lib.infiniopCreateTensorDescriptor.argtypes = [
        POINTER(infiniopTensorDescriptor_t),
        c_uint64,
        POINTER(c_uint64),
        POINTER(c_int64),
        c_int,
    ]
    lib.infiniopCreateTensorDescriptor.restype = c_int
    lib.infiniopDestroyTensorDescriptor.argtypes = [infiniopTensorDescriptor_t]
    lib.infiniopDestroyTensorDescriptor.restype = c_int

    # Handle creation / destruction.
    lib.infiniopCreateHandle.argtypes = [POINTER(infiniopHandle_t)]
    lib.infiniopCreateHandle.restype = c_int
    lib.infiniopDestroyHandle.argtypes = [infiniopHandle_t]
    lib.infiniopDestroyHandle.restype = c_int

    # Runtime device selection.
    lib.infinirtSetDevice.argtypes = [c_int, c_int]
    lib.infinirtSetDevice.restype = c_int

    return lib