rearrange.py 4.54 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
from libinfiniop import (
4
5
    LIBINFINIOP,
    TestTensor,
xgqdut2016's avatar
xgqdut2016 committed
6
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
7
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
8
9
10
11
12
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
13
14
15
16
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
17
18
)

19

pwhMass's avatar
pwhMass committed
20
21
def row_major_strides(shape):
    """Return row-major (C-style) strides, in elements, for a tensor shape.

    The last dimension is contiguous (stride 1); every earlier dimension's
    stride is the product of the sizes of all dimensions after it.

    Args:
        shape: Sequence of dimension sizes.

    Returns:
        List of row-major strides with one entry per dimension of ``shape``
        (an empty list for a 0-d shape).
    """
    strides = []
    stride = 1
    # Walk from the innermost (last) dimension outwards, accumulating the
    # product of the sizes already visited.
    for dim in reversed(shape):
        strides.append(stride)
        stride *= dim
    strides.reverse()
    return strides

37

pwhMass's avatar
pwhMass committed
38
39
def column_major_strides(shape):
    """Return column-major (Fortran-style) strides, in elements, for a shape.

    The first dimension is contiguous (stride 1); every later dimension's
    stride is the product of the sizes of all dimensions before it.

    Args:
        shape: Sequence of dimension sizes.

    Returns:
        List of column-major strides with one entry per dimension of
        ``shape`` (an empty list for a 0-d shape).
    """
    strides = []
    stride = 1
    # Walk from the first dimension to the last, accumulating the product of
    # the sizes already visited.
    for dim in shape:
        strides.append(stride)
        stride *= dim
    return strides


xgqdut2016's avatar
xgqdut2016 committed
56
57
58
59
60
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (shape, x_stride, y_stride) -- strides are counted in elements.
    # 2-D transposes: x is column-major, y is row-major.
    ((100, 100), (1, 100), (100, 1)),
    ((4, 4), (1, 4), (4, 1)),
    # x stores dims 0 and 1 swapped in memory; y is contiguous.
    ((4, 6, 64), (64, 4 * 64, 1), (6 * 64, 64, 1)),
    # Larger 2-D transposes; 2001 is an odd (non power-of-two) size.
    ((2000, 2000), (1, 2000), (2000, 1)),
    ((2001, 2001), (1, 2001), (2001, 1)),
    # x is contiguous; y has its last two dims swapped in memory.
    ((2, 2, 2, 4), (16, 8, 4, 1), (16, 8, 1, 2)),
    # Higher-rank cases: full row-major source to column-major destination.
    (
        (3, 4, 7, 53, 9),
        row_major_strides((3, 4, 7, 53, 9)),
        column_major_strides((3, 4, 7, 53, 9)),
    ),
    (
        (3, 4, 50, 50, 5, 7),
        row_major_strides((3, 4, 50, 50, 5, 7)),
        column_major_strides((3, 4, 50, 50, 5, 7)),
    ),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types.
# Both tolerances are zero: the reference (rearrange_torch) is a plain element
# copy between layouts, so the library result is required to match exactly.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.F32: {"atol": 0, "rtol": 0},
}

# Defaults for the run options; the __main__ block overwrites these from the
# parsed command-line arguments.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

PanZezhongQY's avatar
PanZezhongQY committed
94

95
96
97
def rearrange_torch(y, x, x_shape, y_stride):
    """Reference rearrange: write the values of ``x`` into ``y`` with ``y_stride``.

    ``y`` is modified in place. It is first re-pointed at its own untyped
    storage (offset 0) with shape ``x_shape`` and strides ``y_stride``, then
    the values of ``x`` are written through that strided view. Returns None.
    """
    backing_storage = y.untyped_storage()
    y.set_(backing_storage, 0, x_shape, y_stride)
    source_view = x.view_as(y)
    y[:] = source_view
98
99


PanZezhongQY's avatar
PanZezhongQY committed
100
def test(
    handle, torch_device, shape, x_stride, y_stride, dtype=InfiniDtype.F16, sync=None
):
    """Check infiniopRearrange against the PyTorch reference for one case.

    Args:
        handle: infiniop handle passed to infiniopCreateRearrangeDescriptor.
        torch_device: Device identifier. It indexes ``InfiniDeviceNames`` for
            the log line, is used to allocate the test tensors, and is passed
            to ``profile_operation``.
        shape: Logical shape shared by the source ``x`` and destination ``y``.
        x_stride: Element strides of the source tensor.
        y_stride: Element strides of the destination tensor.
        dtype: InfiniDtype of both tensors.
        sync: Optional callable invoked after the reference computation and
            before the library descriptor is created.

    Raises:
        AssertionError: If the library output differs from the reference
            beyond the tolerance in ``_TOLERANCE_MAP`` for ``dtype``.
    """
    print(
        f"Testing Rearrange on {InfiniDeviceNames[torch_device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}"
    )

    # Allocate on the device passed by the caller. The module-level `device`
    # name only exists when this file runs as __main__.
    x = TestTensor(shape, x_stride, dtype, torch_device)
    # NOTE(review): y starts as all ones, presumably so that elements the
    # kernel fails to write show up as mismatches -- confirm.
    y = TestTensor(shape, y_stride, dtype, torch_device, mode="ones")

    # Expected result goes into y's torch-side tensor.
    rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRearrangeDescriptor(
            handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
        )
    )

    # Everything after creation runs under try/finally so the descriptor is
    # destroyed even when validation or profiling raises.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        for tensor in [x, y]:
            tensor.destroy_desc()

        def lib_rearrange():
            check_error(
                LIBINFINIOP.infiniopRearrange(descriptor, y.data(), x.data(), None)
            )

        lib_rearrange()

        # Validate results
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
        assert torch.allclose(
            y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol
        )

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lib_rearrange, torch_device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(LIBINFINIOP.infiniopDestroyRearrangeDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
145
146
147
148


if __name__ == "__main__":
    args = get_args()

    # Configure testing options. These rebind the module-level globals that
    # test() reads (DEBUG, PROFILE, NUM_PRERUN, NUM_ITERATIONS).
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests: run every case in _TEST_CASES for every dtype in
    # _TENSOR_DTYPES on each device selected by the command-line args.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")