random_sample.py 5.65 KB
Newer Older
xgqdut2016's avatar
xgqdut2016 committed
1
import torch
PanZezhongQY's avatar
PanZezhongQY committed
2
import ctypes
3
from ctypes import c_uint64
xgqdut2016's avatar
xgqdut2016 committed
4
from libinfiniop import (
5
6
    LIBINFINIOP,
    TestTensor,
xgqdut2016's avatar
xgqdut2016 committed
7
    get_test_devices,
PanZezhongQY's avatar
PanZezhongQY committed
8
    check_error,
xgqdut2016's avatar
xgqdut2016 committed
9
10
    test_operator,
    get_args,
xgqdut2016's avatar
xgqdut2016 committed
11
    debug_all,
xgqdut2016's avatar
xgqdut2016 committed
12
13
    get_tolerance,
    profile_operation,
14
15
16
17
18
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
PanZezhongQY's avatar
PanZezhongQY committed
19
20
)

xgqdut2016's avatar
xgqdut2016 committed
21
22
23
24
25
26
# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # voc, random_val, topp, topk, temperature
    (512, 0.8, 0.8, 3, 0.5),
    (4096, 0.05, 0.9, 5, 1.0),
    (16384, 0.15, 0.85, 10, 2.0),
    # topp == 0 or topk == 1 take the argmax path of the reference random_sample.
    (512, 0.08, 0, 3, 0.5),
    (4096, 0.5, 0.9, 1, 1.0),
    (16384, 0.15, 0, 1, 2.0),
    (32000, 0.08, 0.8, 50, 1.0),
    (32000, 0.08, 1.0, 25, 1.0),
    # Disabled large-vocabulary case (reason not recorded here — TODO confirm):
    # (119696, 0.01, 1.0, 100, 1.0),
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

# Zero tolerances: the operator's output is an integer index, so results must match exactly.
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 0, "rtol": 0},
    InfiniDtype.BF16: {"atol": 0, "rtol": 0},
}

# Defaults; the __main__ block below overwrites these from the command-line arguments.
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
PanZezhongQY's avatar
PanZezhongQY committed
52
53


54
def random_sample(data, random_val, topp, topk, voc, temperature):
    """PyTorch reference for the RandomSample operator.

    When ``topp > 0`` and ``topk > 1``:

    1. Sort the logits in descending order.
    2. Shift them so the maximum is 0 and divide by ``temperature``.
    3. Apply a softmax and build the cumulative distribution.
    4. Compute the threshold ``min(cumprob[min(topk, voc) - 1], topp) * random_val``.
    5. Find the first sorted position whose cumulative probability reaches the
       threshold and return that position's original index.

    Otherwise the greedy ``argmax`` of ``data`` is returned.

    Args:
        data: 1-D tensor of logits (the caller in this file passes shape ``[voc]``).
        random_val: random draw that scales the sampling threshold; presumably
            in [0, 1) — TODO confirm.
        topp: cumulative-probability cutoff; ``<= 0`` selects argmax.
        topk: number of top candidates considered; ``<= 1`` selects argmax.
        voc: vocabulary size, used to clamp ``topk``.
        temperature: divisor applied to the shifted logits before the softmax.

    Returns:
        0-d integer tensor holding the selected index into ``data``.
    """
    if topp > 0 and topk > 1:
        sorted_vals, sorted_indices = torch.sort(data, descending=True)
        # Shift by the maximum before scaling so the softmax inputs are <= 0.
        scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
        try:
            probs = torch.softmax(scaled_vals, dim=0)
        except RuntimeError as e:
            # Some PyTorch builds have no Half softmax kernel; retry in float32.
            if "not implemented for 'Half'" in str(e):
                probs = torch.softmax(scaled_vals.to(torch.float32), dim=0)
            else:
                raise
        cum_probs = torch.cumsum(probs, dim=0)

        # Sample inside the smaller of the top-k probability mass and the top-p cutoff.
        k_index = min(topk, voc) - 1
        threshold = min(cum_probs[k_index], topp) * random_val

        try:
            idx = torch.searchsorted(cum_probs, threshold)
        except Exception:
            # Fallback for manual search if torch.searchsorted is not supported
            indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
            idx = (
                indices[0]
                if indices.numel() > 0
                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
            )
        return sorted_indices[idx]

    return torch.argmax(data)
zhangyue's avatar
zhangyue committed
86
87


88
89
def test(
    handle,
    device,
    voc,
    random_val,
    topp,
    topk,
    temperature,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Check infiniop RandomSample against the ``random_sample`` reference for one case.

    The logits are a shuffled ramp of ``voc`` values. The reference and the
    library kernel both run with the same ``random_val``/``topp``/``topk``/
    ``temperature``. The test passes when they return the same index, or
    indices whose logits compare equal. Both implementations are timed when
    ``PROFILE`` is set.

    Args:
        handle: infiniop handle used to create the operator descriptor.
        device: device id; also used to index ``InfiniDeviceNames``.
        voc: number of logits.
        random_val: random draw forwarded to both implementations.
        topp: top-p cutoff forwarded to both implementations.
        topk: top-k cutoff forwarded to both implementations.
        temperature: temperature forwarded to both implementations.
        dtype: dtype of the logits tensor.
        sync: optional callable invoked before and after the library call.

    Raises:
        AssertionError: if the library result disagrees with the reference.
    """
    print(
        f"Testing RandomSample on {InfiniDeviceNames[device]} with voc:{voc} random_val:{random_val} topp:{topp} topk:{topk} temperature:{temperature} dtype:{InfiniDtypeNames[dtype]}"
    )

    # Shuffled ramp of logits. After conversion to a 16-bit dtype, neighbouring
    # values may round to the same number, so several indices can be equally valid.
    _perm = torch.randperm(voc)
    logits = TestTensor.from_torch(
        torch.arange(voc)[_perm].float() * 0.0001, dtype, device
    )

    # Reference result. This can be slow when computed on the device;
    # moving the data to the CPU (data.to("cpu")) would speed it up.
    ans = random_sample(
        logits.torch_tensor(), random_val, topp, topk, voc, temperature
    ).to(torch.int32)

    # Output buffer: a single int32 index (empty shape).
    indices = TestTensor([], None, InfiniDtype.I32, device, mode="zeros")

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateRandomSampleDescriptor(
            handle,
            ctypes.byref(descriptor),
            indices.descriptor,
            logits.descriptor,
        )
    )

    # Everything after creation runs under try/finally so the descriptor is
    # released even when a library call or the correctness assert fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
        for tensor in [logits, indices]:
            tensor.destroy_desc()

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetRandomSampleWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        def lib_random_sample():
            check_error(
                LIBINFINIOP.infiniopRandomSample(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    indices.data(),
                    logits.data(),
                    random_val,
                    topp,
                    topk,
                    temperature,
                    None,
                )
            )

        lib_random_sample()

        if sync is not None:
            sync()

        lib_idx = indices.actual_tensor()
        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            debug_all(
                (lib_idx, logits.actual_tensor()[lib_idx]),
                (ans, logits.torch_tensor()[ans]),
                "or",
                atol=atol,
                rtol=rtol,
            )
        # Accept the same index, or a different index whose logit is equal (a tie).
        assert (
            lib_idx == ans
            or logits.actual_tensor()[lib_idx] == logits.torch_tensor()[ans]
        ), f"RandomSample mismatch: lib index {lib_idx}, reference index {ans}"

        # Profiling workflow
        if PROFILE:
            # fmt: off
            profile_operation("PyTorch", lambda: random_sample(
                logits.torch_tensor(), random_val, topp, topk, voc, temperature
            ), device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation("    lib", lib_random_sample, device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor))
PanZezhongQY's avatar
PanZezhongQY committed
185

186

PanZezhongQY's avatar
PanZezhongQY committed
187
188
189
if __name__ == "__main__":
    args = get_args()

    # Copy the command-line knobs into the module-level settings read by test().
    DEBUG, PROFILE = args.debug, args.profile
    NUM_PRERUN, NUM_ITERATIONS = args.num_prerun, args.num_iterations

    # Run the full case matrix on every requested device.
    for dev in get_test_devices(args):
        test_operator(dev, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")