random_sample.py 5.55 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
3
from ctypes import c_uint64
xgqdut2016's avatar
xgqdut2016 committed
4
from libinfiniop import (
5
6
    LIBINFINIOP,
    TestTensor,
xgqdut2016's avatar
xgqdut2016 committed
7
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
8
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
9
10
    test_operator,
    get_args,
xgqdut2016's avatar
xgqdut2016 committed
11
    debug_all,
xgqdut2016's avatar
xgqdut2016 committed
12
13
    get_tolerance,
    profile_operation,
14
15
16
17
18
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
19
20
)

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
25
26
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # voc, random_val, topp, topk, temperature
    (512, 0.8, 0.8, 3, 0.5),
    (4096, 0.05, 0.9, 5, 1.0),
    (16384, 0.15, 0.85, 10, 2.0),
    # topp == 0 or topk == 1 disables sampling and exercises the argmax path
    (512, 0.08, 0, 3, 0.5),
    (4096, 0.5, 0.9, 1, 1.0),
    (16384, 0.15, 0, 1, 2.0),
    (16384, 0.15, 0, 1, 2.0),
    (32000, 0.08, 0.8, 50, 1.0),
    (32000, 0.08, 1.0, 25, 1.0),
    # (119696, 0.01, 1.0, 100, 1.0),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Zero tolerances: the kernel must select exactly the same index as the
# PyTorch reference (or one whose logit compares equal — see the assert
# in test()).
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.BF16: {"atol": 0, "rtol": 0},
}

# Defaults; overridden from command-line arguments in the __main__ block.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
PanZezhongQY's avatar
PanZezhongQY committed
52
53


54
def random_sample(data, random_val, topp, topk, voc, temperature):
    """Reference top-k / top-p (nucleus) sampling used to validate the kernel.

    Args:
        data: 1-D logits tensor of length ``voc``.
        random_val: draw in [0, 1) selecting a point on the truncated CDF.
        topp: nucleus (cumulative probability) cutoff; sampling requires > 0.
        topk: number of highest-probability candidates kept; sampling
            requires > 1.
        voc: vocabulary size (length of ``data``).
        temperature: softmax temperature applied to the shifted logits.

    Returns:
        A 0-dim integer tensor holding the selected vocabulary index; plain
        ``argmax`` when top-p/top-k sampling is disabled.
    """
    # Greedy path: sampling is disabled unless both topp and topk allow it.
    if not (topp > 0 and topk > 1):
        return torch.argmax(data)

    vals_desc, idx_desc = torch.sort(data, descending=True)

    # Shift by the max logit for numerical stability, then apply temperature.
    shifted = (vals_desc - vals_desc[0]) / temperature
    try:
        probs = torch.softmax(shifted, dim=0)
    except RuntimeError as err:
        # Some backends cannot run softmax in float16; retry in float32.
        if "not implemented for 'Half'" not in str(err):
            raise
        probs = torch.softmax(shifted.to(torch.float32), dim=0)

    cdf = torch.cumsum(probs, dim=0)

    # Truncate the CDF by both top-k and top-p, then scale by the random draw.
    last_kept = min(topk, voc) - 1
    cutoff = min(cdf[last_kept], topp) * random_val

    try:
        pos = torch.searchsorted(cdf, cutoff)
    except Exception:
        # Fallback for manual search if torch.searchsorted is not supported
        hits = (cdf >= cutoff).nonzero(as_tuple=True)[0]
        if hits.numel() > 0:
            pos = hits[0]
        else:
            pos = torch.tensor(len(cdf) - 1, device=cdf.device)
    return idx_desc[pos]
zhangyue's avatar
zhangyue committed
85
86


87
88
def test(
    handle,
    device,
    voc,
    random_val,
    topp,
    topk,
    temperature,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one RandomSample case: library kernel vs. the PyTorch reference.

    Builds shuffled logits of size ``voc``, computes the reference answer with
    ``random_sample``, runs the library operator through the C API, and
    asserts both picked the same index (or indices with equal logits).

    Args:
        handle: library handle created by the test harness.
        device: device id from get_test_devices().
        voc, random_val, topp, topk, temperature: one _TEST_CASES entry.
        dtype: tensor dtype under test.
        sync: optional device-synchronization callback (None on CPU).
    """
    print(
        f"Testing RandomSample on {InfiniDeviceNames[device]} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{InfiniDtypeNames[dtype]}"
    )

    # Distinct, shuffled logit values so argmax/sampling has a unique ordering;
    # the 0.0001 scale keeps values small to avoid fp16/bf16 overflow.
    _perm = torch.randperm(voc)
    logits = TestTensor.from_torch(
        torch.arange(voc)[_perm].float() * 0.0001, dtype, device
    )

    ans = random_sample(
        logits.torch_tensor(), random_val, topp, topk, voc, temperature
    ).to(
        torch.int32
    )  # This reference can be slow on the device; data.to("cpu") can speed it up

    # 0-dim int32 output tensor that receives the sampled index.
    indices = TestTensor([], None, InfiniDtype.I32, device, mode="zeros")

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRandomSampleDescriptor(
            handle,
            ctypes.byref(descriptor),
            indices.descriptor,
            logits.descriptor,
        )
    )

    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
    for tensor in [logits, indices]:
        tensor.destroy_desc()

    # Query and allocate the operator's scratch workspace.
    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetRandomSampleWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)

    def lib_random_sample():
        # Invoke the library kernel; also reused as the profiling body below.
        check_error(
            LIBINFINIOP.infiniopRandomSample(
                descriptor,
                workspace.data(),
                workspace_size.value,
                indices.data(),
                logits.data(),
                random_val,
                topp,
                topk,
                temperature,
                None,
            )
        )

    lib_random_sample()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug_all(
            (indices.actual_tensor(), logits.actual_tensor()[indices.actual_tensor()]),
            (ans, logits.torch_tensor()[ans]),
            "or",
            atol=atol,
            rtol=rtol,
        )
    # Accept either the exact same index, or a different index whose logit
    # compares equal after dtype rounding (ties are legitimate).
    assert (
        indices.actual_tensor() == ans
        or logits.actual_tensor()[indices.actual_tensor()] == logits.torch_tensor()[ans]
    )

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: random_sample(
            logits.torch_tensor(), random_val, topp, topk, voc, temperature
        ), device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_random_sample(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
184

185

PanZezhongQY's avatar
PanZezhongQY committed
186
187
188
if __name__ == "__main__":
    args = get_args()

    # Override the module-level defaults from the command line before any
    # test runs, since test() reads these globals.
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    # Execute tests
    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")