import torch
import ctypes
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)


def row_major_strides(shape):
    """Compute row-major (C-order) strides for a tensor shape.

    Args:
        shape: Tensor shape (sequence of dimension extents).

    Returns:
        List of strides where the last dimension varies fastest.
    """
    # Walk the extents from last to first, prepending the running product
    # of trailing extents; the innermost stride is always 1.
    result = [1]
    running = 1
    for extent in shape[:0:-1]:  # same elements as reversed(shape[1:])
        running *= extent
        result = [running] + result
    return result


def column_major_strides(shape):
    """Compute column-major (Fortran-order) strides for a tensor shape.

    Args:
        shape: Tensor shape (sequence of dimension extents).

    Returns:
        List of strides where the first dimension varies fastest.
    """
    # Accumulate the product of leading extents from first to last;
    # the first dimension always has stride 1.
    running, result = 1, [1]
    for extent in shape[:-1]:
        running *= extent
        result.append(running)
    return result


# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
# Each case is (shape, x_stride, y_stride). A stride of None means the tensor
# uses default contiguous layout.
_TEST_CASES = [
    # (shape, x_stride, y_stride)
    ((100, 100), (1, 100), (100, 1)),  # shape  # x_stride  # y_stride
    ((4, 4), (1, 4), (4, 1)),  # shape  # x_stride  # y_stride
    ((4, 6, 64), (64, 4 * 64, 1), (6 * 64, 64, 1)),  # shape  # x_stride  # y_stride
    ((2000, 2000), (1, 2000), (2000, 1)),  # shape  # x_stride  # y_stride
    ((2001, 2001), (1, 2001), (2001, 1)),  # shape  # x_stride  # y_stride
    ((2, 2, 2, 4), (16, 8, 4, 1), (16, 8, 1, 2)),  # shape  # x_stride  # y_stride
    (
        (3, 4, 7, 53, 9),  # shape
        row_major_strides((3, 4, 7, 53, 9)),  # x_stride
        column_major_strides((3, 4, 7, 53, 9)),  # y_stride
    ),
    (
        (3, 4, 50, 50, 5, 7),  # shape
        row_major_strides((3, 4, 50, 50, 5, 7)),  # x_stride
        column_major_strides((3, 4, 50, 50, 5, 7)),  # y_stride
    ),
    # Broadcast source: x_stride 0 on dim 0 repeats one row 15 times.
    ((15, 10752), (0, 1), (10752, 1)),
    ((2, 2, 2, 2, 2, 2), (4, 8, 16, 32, 64, 128), (64, 32, 16, 8, 4, 2)),  # shape  # x_stride  # y_stride
    ((8, 4, 20, 64), (5120, 64, 256, 1), None),  # shape  # x_stride  # y_stride
    ((8, 4, 20, 64), (5120, 64, 256, 1), (1048576, 262144, 64, 1)),  # shape  # x_stride  # y_stride
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types
# Zero tolerances: the operator is expected to reproduce values exactly.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.F32: {"atol": 0, "rtol": 0},
}

# Defaults; overridden from command-line arguments in the __main__ block.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def rearrange_torch(y, x, x_shape, y_stride):
    """Reference rearrange in PyTorch: re-stride y's storage and copy x into it.

    Args:
        y: destination torch tensor; modified in place.
        x: source torch tensor.
        x_shape: logical shape shared by x and y.
        y_stride: target strides for y; None means contiguous row-major.
    """
    if y_stride is None:
        # Fall back to contiguous (C-order) strides for the destination.
        y_stride = row_major_strides(x_shape)
    # Rebind y to its own storage with the requested shape/strides (no copy).
    y.set_(y.untyped_storage(), 0, x_shape, y_stride)
    # expand_as allows broadcast sources (e.g. an x_stride containing 0).
    y.copy_(x.expand_as(y))


def test(
    handle, torch_device, shape, x_stride, y_stride, dtype=InfiniDtype.F16, sync=None
):
    """Run one rearrange case: compare the library op against the PyTorch reference.

    Args:
        handle: infiniop library handle.
        torch_device: device under test (key into InfiniDeviceNames).
        shape: tensor shape shared by source and destination.
        x_stride: strides of the source tensor (may contain 0 for broadcast).
        y_stride: strides of the destination tensor; None means contiguous.
        dtype: element type for both tensors.
        sync: optional device synchronization callback.

    Raises:
        AssertionError: if library output differs from the PyTorch reference.
    """
    print(
        f"Testing Rearrange on {InfiniDeviceNames[torch_device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}"
    )

    # Use the torch_device parameter (not the module-level loop variable
    # "device") so this function does not depend on __main__'s globals.
    x = TestTensor(shape, x_stride, dtype, torch_device)
    y = TestTensor(shape, y_stride, dtype, torch_device, mode="ones")

    # Compute the reference result in place on y's torch tensor.
    rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRearrangeDescriptor(
            handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from
    # being directly used by the kernel
    for tensor in [x, y]:
        tensor.destroy_desc()

    def lib_rearrange():
        check_error(LIBINFINIOP.infiniopRearrange(descriptor, y.data(), x.data(), None))

    lib_rearrange()

    # Validate results (tolerances are 0 — see _TOLERANCE_MAP).
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyRearrangeDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
151
152
153
154


if __name__ == "__main__":
    args = get_args()

    # Configure testing options (module-level flags read by test())
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests: every case x dtype combination on each requested device
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")