utils.py 3.57 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import ctypes
from .datatypes import *
from .liboperators import infiniopTensorDescriptor_t, CTensor, infiniopHandle_t


def check_error(status):
    """Raise if a library call reported failure.

    Args:
        status: value returned by a library function. Anything equal to 0 is
            treated as success and the function returns None.

    Raises:
        Exception: with message ``"Error code <status>"`` for any other value.
    """
    if status == 0:
        return
    raise Exception("Error code " + str(status))


def to_tensor(tensor, lib):
    """
    Convert a PyTorch tensor to a library Tensor(descriptor, data).

    No data is copied. The descriptor records the tensor's shape, strides
    (element counts, as reported by ``tensor.stride()``) and dtype. The
    returned object carries only the raw ``tensor.data_ptr()`` integer, so
    ``tensor`` must outlive the result.

    Args:
        tensor: the PyTorch tensor to describe.
        lib: loaded operator library exposing ``infiniopCreateTensorDescriptor``.

    Returns:
        ``CTensor(descriptor, data_ptr)``.

    Raises:
        TypeError: if ``tensor.dtype`` has no ``InfiniDtype`` counterpart.
        Exception: (from ``check_error``) if descriptor creation returns a
            non-zero status.
    """
    import torch

    dtype_map = {
        torch.int8: InfiniDtype.I8,
        torch.int16: InfiniDtype.I16,
        torch.int32: InfiniDtype.I32,
        torch.int64: InfiniDtype.I64,
        torch.uint8: InfiniDtype.U8,
        torch.float16: InfiniDtype.F16,
        torch.bfloat16: InfiniDtype.BF16,
        torch.float32: InfiniDtype.F32,
        torch.float64: InfiniDtype.F64,
    }
    # The wider unsigned dtypes only exist in newer PyTorch releases; look
    # them up with getattr so older versions simply do not map them instead
    # of raising AttributeError.
    for torch_name, infini_dt in (
        ("uint16", InfiniDtype.U16),
        ("uint32", InfiniDtype.U32),
        ("uint64", InfiniDtype.U64),
    ):
        torch_dt = getattr(torch, torch_name, None)
        if torch_dt is not None:
            dtype_map[torch_dt] = infini_dt

    dt = dtype_map.get(tensor.dtype)
    if dt is None:
        # Explicit raise rather than assert: asserts vanish under `python -O`
        # and a None dtype would then be handed to the C library.
        raise TypeError(f"Unsupported tensor dtype: {tensor.dtype}")

    ndim = tensor.ndimension()
    shape = (ctypes.c_size_t * ndim)(*tensor.shape)
    strides = (ctypes.c_int64 * ndim)(*(tensor.stride()))
    data_ptr = tensor.data_ptr()

    # Create TensorDescriptor; surface failures like the other library calls
    # in this module instead of silently ignoring the status.
    tensor_desc = infiniopTensorDescriptor_t()
    check_error(
        lib.infiniopCreateTensorDescriptor(
            ctypes.byref(tensor_desc), ndim, shape, strides, dt
        )
    )
    # Create Tensor
    return CTensor(tensor_desc, data_ptr)

def create_workspace(size, torch_device):
    """Allocate a zero-filled byte buffer for use as an operator workspace.

    Args:
        size: number of bytes requested.
        torch_device: device on which to allocate the buffer.

    Returns:
        None when ``size`` is 0; otherwise a 1-D ``torch.uint8`` tensor of
        ``size`` zeros on ``torch_device``.
    """
    if size != 0:
        import torch

        workspace_shape = (size,)
        return torch.zeros(size=workspace_shape, dtype=torch.uint8, device=torch_device)
    return None

def create_handle(lib, device, id=0):
    """Create a library handle via ``infiniopCreateHandle``.

    Args:
        lib: loaded operator library.
        device: device value passed through to ``infiniopCreateHandle``.
        id: device id passed through to ``infiniopCreateHandle`` (default 0).

    Returns:
        The populated ``infiniopHandle_t``.

    Raises:
        Exception: (from ``check_error``) if the call returns a non-zero status.
    """
    new_handle = infiniopHandle_t()
    status = lib.infiniopCreateHandle(ctypes.byref(new_handle), device, id)
    check_error(status)
    return new_handle


def destroy_handle(lib, handle):
    """Release ``handle`` via ``infiniopDestroyHandle``.

    Raises:
        Exception: (from ``check_error``) if the call returns a non-zero status.
    """
    status = lib.infiniopDestroyHandle(handle)
    check_error(status)


def rearrange_tensor(tensor, new_strides):
    """
    Given a PyTorch tensor and a list of new strides, return a new PyTorch tensor with the given strides.

    The result has the same shape, dtype, device and logical values as
    ``tensor``. It is backed by a freshly allocated 1-D buffer laid out
    according to ``new_strides`` (element counts, storage offset 0). Slots
    of the buffer not addressed by any element are zero-filled.

    Args:
        tensor: source tensor; may be non-contiguous, 0-dimensional or empty.
        new_strides: one non-negative stride per dimension of ``tensor``.

    Returns:
        A tensor with shape ``tensor.shape`` and stride ``tuple(new_strides)``.

    Raises:
        ValueError: if ``len(new_strides)`` differs from ``tensor.dim()``, if
            any stride is negative, or if a stride of 0 is given for a
            dimension of size > 1 (distinct elements would share one slot).
    """
    import torch

    shape = tuple(tensor.shape)
    ndim = len(shape)
    if len(new_strides) != ndim:
        raise ValueError(
            f"Expected {ndim} strides for shape {shape}, got {len(new_strides)}"
        )
    for size, stride in zip(shape, new_strides):
        if stride < 0:  # TODO: Support negative strides in the future
            raise ValueError("Negative strides are not supported yet")
        if stride == 0 and size > 1:
            raise ValueError(
                "Zero strides would map distinct elements to the same location"
            )

    # Storage needed: position of the last element plus one. An empty tensor
    # needs none (the (size - 1) terms would otherwise go negative).
    if tensor.numel() == 0:
        storage_size = 0
    else:
        storage_size = 1 + sum(
            (size - 1) * stride for size, stride in zip(shape, new_strides)
        )
    new_tensor = torch.zeros(
        (storage_size,), dtype=tensor.dtype, device=tensor.device
    )

    # Storage position of every element: sum over dims of index[d] * stride[d],
    # built by broadcasting one arange per dim. A 0-dim tensor gets position 0.
    positions = torch.zeros(shape, dtype=torch.int64, device=tensor.device)
    for dim, (size, stride) in enumerate(zip(shape, new_strides)):
        view_shape = [1] * ndim
        view_shape[dim] = size
        positions += (
            torch.arange(size, dtype=torch.int64, device=tensor.device) * stride
        ).view(view_shape)

    # reshape(-1) (unlike view(-1)) accepts non-contiguous inputs. Like the
    # position grid, it enumerates elements in row-major logical order.
    # index_copy_ overwrites instead of accumulating into zeros, so values
    # such as -0.0 are copied exactly.
    if positions.numel() > 0:
        new_tensor.index_copy_(0, positions.reshape(-1), tensor.reshape(-1))
    new_tensor.set_(new_tensor.untyped_storage(), 0, shape, tuple(new_strides))

    return new_tensor