random_sample.py 6.31 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
3
from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float
xgqdut2016's avatar
xgqdut2016 committed
4
from libinfiniop import (
5
    InfiniDtype,
PanZezhongQY's avatar
PanZezhongQY committed
6
7
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
8
9
10
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
11
12
    check_error,
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
13
14
    test_operator,
    get_args,
xgqdut2016's avatar
xgqdut2016 committed
15
    debug_all,
xgqdut2016's avatar
xgqdut2016 committed
16
17
    get_tolerance,
    profile_operation,
xgqdut2016's avatar
xgqdut2016 committed
18
    synchronize_device,
PanZezhongQY's avatar
PanZezhongQY committed
19
20
)

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # voc, random_val, topp, topk, temperature
    (512, 0.8, 0.8, 3, 0.5),
    (4096, 0.05, 0.9, 5, 1.0),
    (16384, 0.15, 0.85, 10, 2.0),
    # topp == 0 or topk == 1 exercise the greedy (arg-max) path of random_sample.
    (512, 0.08, 0, 3, 0.5),
    (4096, 0.5, 0.9, 1, 1.0),
    (16384, 0.15, 0, 1, 2.0),
    (32000, 0.08, 0.8, 50, 1.0),
    (32000, 0.08, 1.0, 25, 1.0),
    # Disabled large-vocabulary case; re-enable when needed.
    # (119696, 0.01, 1.0, 100, 1.0),
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16]

# The operator returns an index, so values are compared exactly (no tolerance).
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 0},
}

# Defaults below are overwritten from the command-line arguments in __main__.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


class RandomSampleDescriptor(Structure):
    """Python-side ctypes mirror of the library's RandomSample descriptor.

    Only a single ``device`` field (``c_int32``) is declared here; the test
    code only ever handles this struct through a pointer.
    """

    _fields_ = [
        ("device", c_int32),
    ]


# Pointer type used whenever a descriptor crosses the ctypes boundary.
infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)


def random_sample(data, random_val, topp, topk, voc, temperature):
    """PyTorch reference for top-k / top-p random sampling.

    Args:
        data: 1-D tensor of scores to sample from (presumably length ``voc``).
        random_val: Multiplier applied to the cumulative-probability cut-off
            (presumably in [0, 1) — confirm with the operator spec).
        topp: Top-p mass; a non-positive value selects greedy arg-max.
        topk: Number of best candidates considered; ``<= 1`` selects greedy
            arg-max.
        voc: Vocabulary size, used to cap ``topk``.
        temperature: Divisor applied to the (max-shifted) sorted scores
            before the softmax.

    Returns:
        Tensor holding the index into ``data`` of the selected element.
    """
    # Greedy path: sampling disabled by either parameter.
    if not (topp > 0 and topk > 1):
        return torch.argmax(data)

    ranked, order = torch.sort(data, descending=True)
    # Shift by the maximum before scaling so the softmax sees values <= 0.
    shifted = (ranked - ranked[0]) / temperature
    cdf = torch.cumsum(torch.softmax(shifted, dim=0), dim=0)

    # Cut-off is the tighter of the top-k mass and top-p, scaled by random_val.
    last_k = min(topk, voc) - 1
    cutoff = min(cdf[last_k], topp) * random_val

    try:
        pos = torch.searchsorted(cdf, cutoff)
    except Exception:
        # Manual equivalent of searchsorted (side="left") if it is unsupported.
        hits = (cdf >= cutoff).nonzero(as_tuple=True)[0]
        if hits.numel() > 0:
            pos = hits[0]
        else:
            pos = torch.tensor(len(cdf) - 1, device=cdf.device)
    return order[pos]
def test(
    lib,
    handle,
    torch_device,
    voc,
    random_val,
    topp,
    topk,
    temperature,
    dtype=torch.float16,
    sync=None,
):
    """Check the library RandomSample operator against the PyTorch reference.

    Builds a randomly permuted input of length ``voc``, computes the expected
    index with ``random_sample``, runs ``lib.infiniopRandomSample`` once and
    asserts that the library chose the same index or an element with the
    same value. When ``PROFILE`` is set, both implementations are also timed.
    The operator descriptor is always destroyed, even if a check fails.

    Args:
        lib: Loaded operator library (from ``open_lib()``).
        handle: ``infiniopHandle_t`` passed to descriptor creation.
        torch_device: PyTorch device string the tensors are moved to.
        voc: Length of the input tensor.
        random_val, topp, topk, temperature: Sampling parameters forwarded
            to both the reference and the library call.
        dtype: Dtype of the input tensor.
        sync: Optional callable invoked before the descriptor is created.

    Raises:
        AssertionError: If the library result disagrees with the reference.
    """
    print(
        f"Testing RandomSample on {torch_device} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{dtype}"
    )

    # Evenly spaced values in random order. Casting to a low-precision dtype
    # may round neighbouring values together, which is why the final check
    # also accepts an index whose value equals the reference's.
    data = torch.arange(voc).float() * 0.0001
    _perm = torch.randperm(voc)
    data = data[_perm].to(dtype).to(torch_device)

    # The reference may be slow on the device; computing it on data.to("cpu")
    # can speed this up.
    ans = random_sample(data, random_val, topp, topk, voc, temperature)

    # Output buffer: a single 0-d int64 index written by the library.
    indices = torch.zeros([], dtype=torch.int64).to(torch_device)

    x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]

    indices_tensor.descriptor.contents.dt = InfiniDtype.U64  # treat int64 as uint64

    if sync is not None:
        sync()

    descriptor = infiniopRandomSampleDescriptor_t()
    check_error(
        lib.infiniopCreateRandomSampleDescriptor(
            handle,
            ctypes.byref(descriptor),
            indices_tensor.descriptor,
            x_tensor.descriptor,
        )
    )

    # Everything after a successful create runs under try/finally so the
    # descriptor is released even when a later step or the assertion fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        for tensor in [x_tensor, indices_tensor]:
            tensor.destroyDesc(lib)

        workspace_size = c_uint64(0)
        check_error(
            lib.infiniopGetRandomSampleWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = create_workspace(workspace_size.value, torch_device)

        def lib_random_sample():
            check_error(
                lib.infiniopRandomSample(
                    descriptor,
                    workspace.data_ptr() if workspace is not None else None,
                    workspace_size.value,
                    indices_tensor.data,
                    x_tensor.data,
                    random_val,
                    topp,
                    topk,
                    temperature,
                    None,
                )
            )

        lib_random_sample()

        if torch_device == "npu":
            synchronize_device(torch_device)

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug_all(
                (indices.type(ans.dtype), data[indices]),
                (ans, data[ans]),
                "or",
                atol=atol,
                rtol=rtol,
            )
        # Equal values make the chosen index ambiguous, so accept either match.
        assert indices.type(ans.dtype) == ans or data[ans] == data[indices]

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: random_sample(
                    data, random_val, topp, topk, voc, temperature
                ), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    # Signatures below mirror the call sites in test().
    lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
    lib.infiniopCreateRandomSampleDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRandomSampleDescriptor_t),
        infiniopTensorDescriptor_t,  # output indices descriptor
        infiniopTensorDescriptor_t,  # input data descriptor
    ]

    lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
    lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
        infiniopRandomSampleDescriptor_t,
        POINTER(c_uint64),
    ]

    lib.infiniopRandomSample.restype = c_int32
    lib.infiniopRandomSample.argtypes = [
        infiniopRandomSampleDescriptor_t,
        c_void_p,  # workspace pointer (None when no workspace is allocated)
        c_uint64,  # workspace size in bytes
        # NOTE(review): the output pointer is declared c_uint64, which only
        # works if indices_tensor.data is an integer address — confirm.
        c_uint64,  # output indices buffer
        c_void_p,  # input data buffer
        c_float,  # random_val
        c_float,  # topp
        c_int32,  # topk
        c_float,  # temperature
        c_void_p,  # trailing pointer argument (test() passes None)
    ]

    lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
    lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
        infiniopRandomSampleDescriptor_t,
    ]

    # Override the module-level defaults with command-line options.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")