random_sample.py 6.83 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
xgqdut2016's avatar
xgqdut2016 committed
3
4
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
PanZezhongQY's avatar
PanZezhongQY committed
5
6
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
xgqdut2016's avatar
xgqdut2016 committed
7
8
9
    open_lib,
    to_tensor,
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
10
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
11
    rearrange_if_needed,
PanZezhongQY's avatar
PanZezhongQY committed
12
    create_workspace,
xgqdut2016's avatar
xgqdut2016 committed
13
14
    test_operator,
    get_args,
xgqdut2016's avatar
xgqdut2016 committed
15
    debug_all,
xgqdut2016's avatar
xgqdut2016 committed
16
17
    get_tolerance,
    profile_operation,
xgqdut2016's avatar
xgqdut2016 committed
18
    synchronize_device,
PanZezhongQY's avatar
PanZezhongQY committed
19
20
)

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
xgqdut2016's avatar
xgqdut2016 committed
25

xgqdut2016's avatar
xgqdut2016 committed
26
27
_TEST_CASES = [
    # voc, random_val, topp, topk, temperature
xgqdut2016's avatar
xgqdut2016 committed
28
29
30
31
32
33
34
35
36
37
    (512, 0.8, 0.8, 3, 0.5),
    (4096, 0.05, 0.9, 5, 1.0),
    (16384, 0.15, 0.85, 10, 2.0),
    (512, 0.08, 0, 3, 0.5),
    (4096, 0.5, 0.9, 1, 1.0),
    (16384, 0.15, 0, 1, 2.0),
    (16384, 0.15, 0, 1, 2.0),
    (32000, 0.08, 0.8, 50, 1.0),
    (32000, 0.08, 1.0, 25, 1.0),
    # (119696, 0.01, 1.0, 100, 1.0),
xgqdut2016's avatar
xgqdut2016 committed
38
39
40
]

# Data types used for testing
xgqdut2016's avatar
xgqdut2016 committed
41
42
43
44
45
_TENSOR_DTYPES = [torch.float16]

_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 0},
}
xgqdut2016's avatar
xgqdut2016 committed
46
47


xgqdut2016's avatar
xgqdut2016 committed
48
DEBUG = False
xgqdut2016's avatar
xgqdut2016 committed
49
50
51
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
PanZezhongQY's avatar
PanZezhongQY committed
52
53
54
55
56
57
58
59
60
61


class RandomSampleDescriptor(Structure):
    _fields_ = [("device", c_int32)]


infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)


def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
62
    indices = torch.zeros([topk], dtype=torch.int64)
PanZezhongQY's avatar
PanZezhongQY committed
63
64
    dataNp = data.clone().detach()
    sorted_indices = torch.arange(voc)
65

PanZezhongQY's avatar
PanZezhongQY committed
66
67
    for i in range(topk):
        for j in range(i + 1, voc):
68
            if dataNp[i] < dataNp[j]:
PanZezhongQY's avatar
PanZezhongQY committed
69
70
71
72
73
74
75
                tmp = dataNp[i].clone().detach()
                dataNp[i] = dataNp[j].clone().detach()
                dataNp[j] = tmp

                tmpInd = sorted_indices[i].clone().detach()
                sorted_indices[i] = sorted_indices[j].clone().detach()
                sorted_indices[j] = tmpInd
76
77
78
79

    # sorted_indices = torch.argsort(dataNp, descending=True)
    indices = sorted_indices[:topk]

PanZezhongQY's avatar
PanZezhongQY committed
80
    dataNp = dataNp[sorted_indices]
81

PanZezhongQY's avatar
PanZezhongQY committed
82
83
    globalM = dataNp[0]
    dataNp = (dataNp - globalM) / temperature
84
    dataNp = torch.softmax(dataNp.float(), dim=0)
PanZezhongQY's avatar
PanZezhongQY committed
85
86
87
    sum_s = 0
    for end in range(topk):
        sum_s += dataNp[end]
88
        if sum_s >= topp:
PanZezhongQY's avatar
PanZezhongQY committed
89
            break
90
    if end < topk - 1:
PanZezhongQY's avatar
PanZezhongQY committed
91
92
93
        end += 1
    else:
        end = topk
94

PanZezhongQY's avatar
PanZezhongQY committed
95
96
97
98
    sum_s = 0
    for i in range(end):
        sum_s += dataNp[i]
    random_val *= sum_s
99

PanZezhongQY's avatar
PanZezhongQY committed
100
101
102
    sum_s = 0
    for i in range(end):
        sum_s += dataNp[i]
103
        if random_val < sum_s:
PanZezhongQY's avatar
PanZezhongQY committed
104
105
            return indices[i]

106

PanZezhongQY's avatar
PanZezhongQY committed
107
108
109
def random_sample_0(data):
    return torch.argmax(data)

110
111
112
113
114
115
116
117
118
119
120
121
122

def test(
    lib,
    handle,
    torch_device,
    voc,
    random_val,
    topp,
    topk,
    temperature,
    x_dtype=torch.float16,
):
    print(f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}")
xgqdut2016's avatar
xgqdut2016 committed
123

PanZezhongQY's avatar
PanZezhongQY committed
124
125
126
    data = torch.arange(voc).float() * 0.0001
    _perm = torch.randperm(voc)
    data = data[_perm].to(x_dtype).to(torch_device)
127
128
129
130
    if topp > 0 and topk > 1:
        ans = random_sample(
            data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
        )
PanZezhongQY's avatar
PanZezhongQY committed
131
132
    else:
        ans = random_sample_0(data)
xgqdut2016's avatar
xgqdut2016 committed
133

PanZezhongQY's avatar
PanZezhongQY committed
134
    indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
xgqdut2016's avatar
xgqdut2016 committed
135
136
137

    x_tensor, indices_tensor = [to_tensor(tensor, lib) for tensor in [data, indices]]

PanZezhongQY's avatar
PanZezhongQY committed
138
139
140
141
142
    indices_tensor.descriptor.contents.dt = U64  # treat int64 as uint64

    descriptor = infiniopRandomSampleDescriptor_t()
    check_error(
        lib.infiniopCreateRandomSampleDescriptor(
143
144
145
146
            handle,
            ctypes.byref(descriptor),
            indices_tensor.descriptor,
            x_tensor.descriptor,
PanZezhongQY's avatar
PanZezhongQY committed
147
148
149
150
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
xgqdut2016's avatar
xgqdut2016 committed
151
152
    for tensor in [x_tensor, indices_tensor]:
        tensor.descriptor.contents.invalidate()
PanZezhongQY's avatar
PanZezhongQY committed
153
154
155
156
157
158
159

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetRandomSampleWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
160
    workspace = create_workspace(workspace_size.value, torch_device)
xgqdut2016's avatar
xgqdut2016 committed
161

xgqdut2016's avatar
xgqdut2016 committed
162
163
164
165
166
167
168
169
170
171
172
173
174
175
    def lib_random_sample():
        check_error(
            lib.infiniopRandomSample(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                indices_tensor.data,
                x_tensor.data,
                random_val,
                topp,
                topk,
                temperature,
                None,
            )
PanZezhongQY's avatar
PanZezhongQY committed
176
177
        )

xgqdut2016's avatar
xgqdut2016 committed
178
179
180
181
182
183
184
185
186
187
188
189
    if torch_device == "npu":
        synchronize_device(torch_device)

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug_all(
            (indices[0].type(ans.dtype), data[indices[0]]),
            (ans, data[ans]),
            "or",
            atol=atol,
            rtol=rtol,
        )
PanZezhongQY's avatar
PanZezhongQY committed
190
    assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
xgqdut2016's avatar
xgqdut2016 committed
191

xgqdut2016's avatar
xgqdut2016 committed
192
193
194
195
196
197
198
199
200
201
202
203
    # Profiling workflow
    if PROFILE:
        # fmt: off
        if topp > 0 and topk > 1:
            profile_operation("PyTorch", lambda: random_sample(
                data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
            ), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        else:
            profile_operation("PyTorch", lambda: random_sample_0(data), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        
        profile_operation("    lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
PanZezhongQY's avatar
PanZezhongQY committed
204
205
    check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))

206

PanZezhongQY's avatar
PanZezhongQY committed
207
208
209
if __name__ == "__main__":
    args = get_args()
    lib = open_lib()
xgqdut2016's avatar
xgqdut2016 committed
210

PanZezhongQY's avatar
PanZezhongQY committed
211
212
213
214
215
216
    lib.infiniopCreateRandomSampleDescriptor.restype = c_int32
    lib.infiniopCreateRandomSampleDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopRandomSampleDescriptor_t),
        infiniopTensorDescriptor_t,
    ]
xgqdut2016's avatar
xgqdut2016 committed
217

PanZezhongQY's avatar
PanZezhongQY committed
218
219
220
221
222
    lib.infiniopGetRandomSampleWorkspaceSize.restype = c_int32
    lib.infiniopGetRandomSampleWorkspaceSize.argtypes = [
        infiniopRandomSampleDescriptor_t,
        POINTER(c_uint64),
    ]
xgqdut2016's avatar
xgqdut2016 committed
223

PanZezhongQY's avatar
PanZezhongQY committed
224
225
226
227
228
229
230
231
232
233
234
235
236
    lib.infiniopRandomSample.restype = c_int32
    lib.infiniopRandomSample.argtypes = [
        infiniopRandomSampleDescriptor_t,
        c_void_p,
        c_uint64,
        c_uint64,
        c_void_p,
        c_float,
        c_float,
        c_int32,
        c_float,
        c_void_p,
    ]
xgqdut2016's avatar
xgqdut2016 committed
237

PanZezhongQY's avatar
PanZezhongQY committed
238
239
240
241
242
    lib.infiniopDestroyRandomSampleDescriptor.restype = c_int32
    lib.infiniopDestroyRandomSampleDescriptor.argtypes = [
        infiniopRandomSampleDescriptor_t,
    ]

xgqdut2016's avatar
xgqdut2016 committed
243
    DEBUG = args.debug
xgqdut2016's avatar
xgqdut2016 committed
244
245
246
247
248
249
250
251
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)

PanZezhongQY's avatar
PanZezhongQY committed
252
    print("\033[92mTest passed!\033[0m")