# Test driver for the InfiniOP Clip operator (compares against torch.clamp).
#!/usr/bin/env python3

import ctypes
from ctypes import c_uint64
from enum import Enum, auto

import torch

from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # shape, x_stride, y_stride, min_val, max_val
    # 基本形状测试
    ((10,), None, None, -1.0, 1.0),
    ((5, 10), None, None, -1.0, 1.0),
    ((2, 3, 4), None, None, -1.0, 1.0),
    # 不同的min_val和max_val
    ((10,), None, None, 0.0, 2.0),
    ((5, 10), None, None, 0.0, 2.0),
    ((2, 3, 4), None, None, 0.0, 2.0),
    ((10,), None, None, -2.0, 0.0),
    ((5, 10), None, None, -2.0, 0.0),
    ((2, 3, 4), None, None, -2.0, 0.0),
    # 奇怪形状测试
PanZezhong's avatar
PanZezhong committed
42
43
    ((7, 13), None, None, -1.0, 1.0),  # 质数维度
    ((3, 5, 7), None, None, -1.0, 1.0),  # 三维质数
goldenfox2025's avatar
goldenfox2025 committed
44
    # 非标准形状测试
PanZezhong's avatar
PanZezhong committed
45
46
47
    ((1, 1), None, None, -1.0, 1.0),  # 最小形状
    ((100, 100), None, None, -1.0, 1.0),  # 大形状
    ((16, 16, 16), None, None, -1.0, 1.0),  # 大三维
goldenfox2025's avatar
goldenfox2025 committed
48
49
    # 极端值测试
    ((10,), None, None, -1000.0, 1000.0),  # 大范围
PanZezhong's avatar
PanZezhong committed
50
51
    ((10,), None, None, -0.001, 0.001),  # 小范围
    ((10,), None, None, 0.0, 0.0),  # min=max
goldenfox2025's avatar
goldenfox2025 committed
52
53
]

# Data types exercised by every test case.
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]


# Comparison tolerances per dtype; half-precision types get looser bounds
# because clamping at the boundary can round differently in fp16/bf16.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
    InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-6},
    InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-3},
}


class Inplace(Enum):
    """Whether the clip output is a fresh tensor or aliases the input."""

    OUT_OF_PLACE = auto()  # y is a separate tensor from x
    INPLACE_X = auto()  # y aliases x (in-place clip)


# In-place modes exercised for every base case (inplace variant runs first).
_INPLACE = [
    Inplace.INPLACE_X,
    Inplace.OUT_OF_PLACE,
]

# Cross every base case with every in-place mode, appending the mode
# as the final element of each case tuple.
_TEST_CASES = [base + (mode,) for base in _TEST_CASES_ for mode in _INPLACE]

# Runtime options; overwritten from CLI arguments in the __main__ block.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


87
88
def clip(y, x, min_val, max_val):
    torch.clamp(x, min_val, max_val, out=y)
goldenfox2025's avatar
goldenfox2025 committed
89
90
91
92


def test(
    handle,
    device,
    shape,
    x_stride=None,
    y_stride=None,
    min_val=-1.0,
    max_val=1.0,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=InfiniDtype.F32,
    sync=None,
):
    """Run one Clip test case: compare the library operator against torch.clamp.

    Builds input tensor `x` plus broadcast-style `min_`/`max_` bound tensors
    (zero strides, constant bias), computes the torch reference, then creates,
    runs, and destroys the library descriptor, asserting the outputs match
    within the dtype's tolerance.

    NOTE(review): `handle`/`device` are opaque libinfiniop objects; `sync` is
    presumably a device-synchronization callable — confirm against test_operator.
    """
    x = TestTensor(shape, x_stride, dtype, device)
    # min_/max_ are constant tensors built with stride 0 in every dimension,
    # so a single scalar value is broadcast across the whole shape.
    min_ = TestTensor(
        shape, [0 for _ in shape], dtype, device, mode="zeros", bias=min_val
    )
    max_ = TestTensor(
        shape, [0 for _ in shape], dtype, device, mode="zeros", bias=max_val
    )

    if inplace == Inplace.INPLACE_X:
        # In-place only makes sense when input and output layouts agree.
        if x_stride != y_stride:
            return
        y = x
    else:
        y = TestTensor(shape, y_stride, dtype, device)

    # Broadcast outputs cannot be written to; skip such cases.
    if y.is_broadcast():
        return

    print(
        f"Testing Clip on {InfiniDeviceNames[device]} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} "
        f"min_val:{min_val} max_val:{max_val} dtype:{InfiniDtypeNames[dtype]} inplace:{inplace}"
    )

    # Compute the torch reference result into y's host-side tensor.
    clip(y.torch_tensor(), x.torch_tensor(), min_val, max_val)

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()

    check_error(
        LIBINFINIOP.infiniopCreateClipDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            x.descriptor,
            min_.descriptor,
            max_.descriptor,
        )
    )

    # Query and allocate whatever scratch space the operator requires.
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetClipWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x.device)

    def lib_clip():
        # Invoke the library Clip kernel on the device buffers.
        check_error(
            LIBINFINIOP.infiniopClip(
                descriptor,
                workspace.data() if workspace is not None else None,
                workspace_size.value,
                y.data(),
                x.data(),
                min_.data(),
                max_.data(),
                None,
            )
        )

    lib_clip()

    # Destroy the tensor descriptors (the operator descriptor outlives them
    # so profiling below can re-run the kernel on the raw data pointers).
    for tensor in [x, y, min_, max_]:
        tensor.destroy_desc()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: clip(y.torch_tensor(), x.torch_tensor(), min_val, max_val), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_clip(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(LIBINFINIOP.infiniopDestroyClipDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()
    # Configure testing options from the command line.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")