rearrange.py 4.58 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
from libinfiniop import (
4
5
    LIBINFINIOP,
    TestTensor,
xgqdut2016's avatar
xgqdut2016 committed
6
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
7
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
8
9
10
11
12
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
13
14
15
16
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
17
18
)

19

pwhMass's avatar
pwhMass committed
20
21
def row_major_strides(shape):
    """Return the row-major (C-order) strides, in elements, for `shape`.

    Args:
        shape: tensor shape (sequence of dimension extents).

    Returns:
        List of strides, last dimension contiguous (stride 1).
    """
    # Build from the innermost dimension outward: each new stride is the
    # previous one scaled by the extent just processed.
    strides = [1]
    for extent in shape[:0:-1]:
        strides.insert(0, strides[0] * extent)
    return strides

37

pwhMass's avatar
pwhMass committed
38
39
def column_major_strides(shape):
    """Return the column-major (Fortran-order) strides, in elements, for `shape`.

    Args:
        shape: tensor shape (sequence of dimension extents).

    Returns:
        List of strides, first dimension contiguous (stride 1).
    """
    # Build from the outermost dimension inward: each new stride is the
    # previous one scaled by the extent just processed.
    strides = [1]
    for extent in shape[:-1]:
        strides.append(strides[-1] * extent)
    return strides


xgqdut2016's avatar
xgqdut2016 committed
56
57
58
59
60
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # (shape, x_stride, y_stride)
    ((100, 100), (1, 100), (100, 1)),
    ((4, 4), (1, 4), (4, 1)),
    ((4, 6, 64), (64, 4 * 64, 1), (6 * 64, 64, 1)),
    ((2000, 2000), (1, 2000), (2000, 1)),
    # Odd extents exercise non-aligned code paths.
    ((2001, 2001), (1, 2001), (2001, 1)),
    ((2, 2, 2, 4), (16, 8, 4, 1), (16, 8, 1, 2)),
    (
        (3, 4, 7, 53, 9),  # shape
        row_major_strides((3, 4, 7, 53, 9)),  # x_stride
        column_major_strides((3, 4, 7, 53, 9)),  # y_stride
    ),
    (
        (3, 4, 50, 50, 5, 7),  # shape
        row_major_strides((3, 4, 50, 50, 5, 7)),  # x_stride
        column_major_strides((3, 4, 50, 50, 5, 7)),  # y_stride
    ),
    # Broadcast source: stride 0 on the first axis replicates one row of x.
    ((15, 10752), (0, 1), (10752, 1)),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]

# Tolerance map for different data types.
# Rearrange is a pure data movement (exact copy), so zero tolerance.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.F32: {"atol": 0, "rtol": 0},
}

# Defaults; overridden from command-line args in the __main__ section.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000

PanZezhongQY's avatar
PanZezhongQY committed
95

96
97
def rearrange_torch(y, x, x_shape, y_stride):
    """PyTorch reference rearrange: re-view y's storage under the target
    shape/strides, then copy x element-wise into that view (in place)."""
    view = y.set_(y.untyped_storage(), 0, x_shape, y_stride)
    view.copy_(x.expand_as(view))
99
100


PanZezhongQY's avatar
PanZezhongQY committed
101
def test(
    handle, torch_device, shape, x_stride, y_stride, dtype=InfiniDtype.F16, sync=None
):
    """Run one rearrange case: build strided x/y tensors, compute the PyTorch
    reference result, run the library operator, and compare the outputs.

    Args:
        handle: infiniop library handle.
        torch_device: device enum the test tensors are created on.
        shape: tensor shape shared by x and y.
        x_stride: element strides of the source tensor.
        y_stride: element strides of the destination tensor.
        dtype: element type (see _TENSOR_DTYPES).
        sync: optional device synchronization callback.
    """
    print(
        f"Testing Rearrange on {InfiniDeviceNames[torch_device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{InfiniDtypeNames[dtype]}"
    )

    # Use the torch_device parameter (previously this relied on a module-level
    # `device` global set only by the __main__ loop).
    x = TestTensor(shape, x_stride, dtype, torch_device)
    y = TestTensor(shape, y_stride, dtype, torch_device, mode="ones")

    # Reference result, written in place into y's torch tensor.
    rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRearrangeDescriptor(
            handle, ctypes.byref(descriptor), y.descriptor, x.descriptor
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [x, y]:
        tensor.destroy_desc()

    def lib_rearrange():
        check_error(LIBINFINIOP.infiniopRearrange(descriptor, y.data(), x.data(), None))

    lib_rearrange()

    # Validate results
    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rearrange_torch(y.torch_tensor(), x.torch_tensor(), shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyRearrangeDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
146
147
148
149


if __name__ == "__main__":
    args = get_args()

    # Apply command-line overrides to the module-level test configuration.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Run every test case on every requested device.
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")