rearrange.py 4.56 KB
Newer Older
PanZezhongQY's avatar
PanZezhongQY committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import ctypes
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
    open_lib,
    to_tensor,
    CTensor,
    DeviceEnum,
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    check_error,
    rearrange_tensor,
)

from operatorspy.tests.test_utils import get_args
import torch


class RerrangeDescriptor(Structure):
    """ctypes layout used for the library's rearrange descriptor.

    Only a single ``device`` field (32-bit int) is declared here.
    NOTE(review): presumably the native struct carries further
    backend-specific state that this script treats as opaque -- confirm.
    The class name is spelled "Rerrange" (sic); it is kept unchanged because
    the pointer type below refers to it by this name.
    """

    _fields_ = [("device", c_int32)]


# Pointer-to-descriptor type; used for the out-parameter of the create call
# and as the descriptor argument of the run/destroy calls.
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor)


def test(
    lib,
    handle,
    torch_device,
    x_shape,
    x_stride,
    y_shape,
    y_stride,
    x_dtype=torch.float16,
):
    """Run one rearrange case on ``torch_device`` and verify the result.

    A random source tensor ``x`` and a zero-filled destination tensor ``y``
    are created, optionally re-laid-out with ``rearrange_tensor`` when a
    stride is given, and passed to ``infiniopRearrange``. Afterwards ``y``
    must match ``x``.

    Args:
        lib: Loaded operator library exposing the ``infiniop*Rearrange*``
            entry points.
        handle: Device handle passed to the descriptor constructor.
        torch_device: Torch device string the tensors are moved to.
        x_shape: Shape of the source tensor.
        x_stride: Strides for the source tensor, or ``None`` to keep the
            layout torch produced.
        y_shape: Shape of the destination tensor.
        y_stride: Strides for the destination tensor, or ``None``.
        x_dtype: Element type used for both tensors.

    Raises:
        AssertionError: If the destination does not match the source after
            the operator ran.
    """
    print(
        f"Testing Rerrange on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} y_shape:{y_shape} y_stride:{y_stride} x_dtype:{x_dtype}"
    )
    x = torch.rand(x_shape, dtype=x_dtype).to(torch_device)
    y = torch.zeros(y_shape, dtype=x_dtype).to(torch_device)
    if x_stride is not None:
        x = rearrange_tensor(x, x_stride)
    if y_stride is not None:
        y = rearrange_tensor(y, y_stride)
    x_tensor = to_tensor(x, lib)
    y_tensor = to_tensor(y, lib)

    descriptor = infiniopRearrangeDescriptor_t()
    check_error(
        lib.infiniopCreateRearrangeDescriptor(
            handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor
        )
    )

    # From here on the native descriptor exists, so make sure it is destroyed
    # even if the kernel call or the comparison below fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        x_tensor.descriptor.contents.invalidate()
        y_tensor.descriptor.contents.invalidate()

        # The last argument is passed as a null pointer.
        # NOTE(review): presumably the stream/queue parameter -- confirm.
        check_error(
            lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None)
        )

        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O.
        if not torch.allclose(x, y, atol=0, rtol=1e-3):
            raise AssertionError(
                f"Rearrange output mismatch on {torch_device}: "
                f"x_shape={x_shape} x_stride={x_stride} "
                f"y_shape={y_shape} y_stride={y_stride} dtype={x_dtype}"
            )
    finally:
        check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor))


def test_cpu(lib, test_cases):
    """Run every rearrange case on the CPU backend.

    Args:
        lib: Loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of ``((x_shape, x_stride), (y_shape, y_stride))``
            entries.
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    try:
        for case in test_cases:
            (x_shape, x_stride), (y_shape, y_stride) = case[0], case[1]
            test(lib, handle, "cpu", x_shape, x_stride, y_shape, y_stride)
    finally:
        # Release the handle even when one of the cases raises.
        destroy_handle(lib, handle)


def test_cuda(lib, test_cases):
    """Run every rearrange case on the CUDA backend.

    Args:
        lib: Loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of ``((x_shape, x_stride), (y_shape, y_stride))``
            entries.
    """
    device = DeviceEnum.DEVICE_CUDA
    handle = create_handle(lib, device)
    try:
        for case in test_cases:
            (x_shape, x_stride), (y_shape, y_stride) = case[0], case[1]
            test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
    finally:
        # Release the handle even when one of the cases raises.
        destroy_handle(lib, handle)

88

PanZezhongQY's avatar
PanZezhongQY committed
89
90
def test_bang(lib, test_cases):
    """Run every rearrange case on the BANG backend (torch device "mlu").

    Args:
        lib: Loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of ``((x_shape, x_stride), (y_shape, y_stride))``
            entries.
    """
    # Imported only for its side effects; the name is not used below.
    # NOTE(review): presumably registers the "mlu" device with torch -- confirm.
    import torch_mlu

    device = DeviceEnum.DEVICE_BANG
    handle = create_handle(lib, device)
    try:
        for case in test_cases:
            (x_shape, x_stride), (y_shape, y_stride) = case[0], case[1]
            test(lib, handle, "mlu", x_shape, x_stride, y_shape, y_stride)
    finally:
        # Release the handle even when one of the cases raises.
        destroy_handle(lib, handle)

100

PanZezhongQY's avatar
PanZezhongQY committed
101
102
103
104
105
106
107
108
109
def test_ascend(lib, test_cases):
    """Run every rearrange case on the Ascend backend (torch device "npu").

    Args:
        lib: Loaded operator library, forwarded to ``create_handle``,
            ``test`` and ``destroy_handle``.
        test_cases: Iterable of ``((x_shape, x_stride), (y_shape, y_stride))``
            entries.
    """
    # Imported only for its side effects; the name is not used below.
    # NOTE(review): presumably registers the "npu" device with torch -- confirm.
    import torch_npu

    device = DeviceEnum.DEVICE_ASCEND
    handle = create_handle(lib, device)
    try:
        for case in test_cases:
            (x_shape, x_stride), (y_shape, y_stride) = case[0], case[1]
            test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
    finally:
        # Release the handle even when one of the cases raises.
        destroy_handle(lib, handle)

PanZezhongQY's avatar
PanZezhongQY committed
112
113
114
115
116
117
118
119
120
121
122
123

if __name__ == "__main__":
    # Script entry point: parse the backend flags, declare the C signatures of
    # the rearrange entry points, then run the cases on each selected backend.
    args = get_args()

    # Each entry is ((src_shape, src_stride), (dst_shape, dst_stride));
    # a stride of None leaves that tensor in the layout torch produced.
    test_cases = [
        (((2, 4, 32), None), ((2, 4, 32), (256, 64, 1))),
        (((32, 6, 64), (64, 2560, 1)), ((32, 6, 64), None)),
        (((4, 6, 64), (64, 2560, 1)), ((4, 6, 64), (131072, 64, 1))),
        (((1, 32, 64), (2048, 64, 1)), ((1, 32, 64), (2048, 64, 1))),
        (((32, 1, 64), (64, 2560, 1)), ((32, 1, 64), (64, 64, 1))),
        (((4, 1, 64), (64, 2560, 1)), ((4, 1, 64), (64, 11264, 1))),
        (((64,), (1,)), ((64,), (1,))),
    ]

    lib = open_lib()

    # Create: (handle, out descriptor, dst tensor desc, src tensor desc) -> status
    lib.infiniopCreateRearrangeDescriptor.restype = c_int32
    lib.infiniopCreateRearrangeDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRearrangeDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    # Run: (descriptor, dst data, src data, extra pointer) -> status
    lib.infiniopRearrange.restype = c_int32
    lib.infiniopRearrange.argtypes = [
        infiniopRearrangeDescriptor_t,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    # Destroy: (descriptor) -> status
    lib.infiniopDestroyRearrangeDescriptor.restype = c_int32
    lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t]

    # Backends run in a fixed order; each flag is read only when its turn
    # comes, right before the corresponding runner would be invoked.
    backend_runners = (
        ("cpu", test_cpu),
        ("cuda", test_cuda),
        ("bang", test_bang),
        ("ascend", test_ascend),
    )
    for flag_name, run_backend in backend_runners:
        if getattr(args, flag_name):
            run_backend(lib, test_cases)

    print("\033[92mTest passed!\033[0m")