# test_speculative_sampling.py -- tests for sgl_kernel.tree_speculative_sampling_target_only.
import torch
import torch.nn.functional as F
from sgl_kernel import tree_speculative_sampling_target_only


# Expected kernel outputs for the threshold configurations this file exercises,
# keyed by (threshold_single, threshold_acc). Each value is
# (predicts, accept_index, accept_token_num) as nested Python lists.
_EXPECTED_RESULTS = {
    (1, 1): (
        [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
        [[0, 3, 4, 5], [6, 10, 11, -1]],
        [3, 2],
    ),
    (0, 0): (
        [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
        [[0, 1, 2, -1], [6, 10, 11, -1]],
        [2, 2],
    ),
}


def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
    """Run ``tree_speculative_sampling_target_only`` on a fixed two-request draft tree.

    Builds a batch of two requests with six draft candidates each, target logits
    that become exactly one-hot after temperature scaling, and all-zero draft
    probabilities, then calls the kernel with the given thresholds.

    If ``(threshold_single, threshold_acc)`` has an entry in ``_EXPECTED_RESULTS``,
    the kernel outputs are asserted against it. This means a plain pytest run with
    the default arguments actually verifies the kernel.

    Args:
        threshold_single: forwarded unchanged to the kernel.
        threshold_acc: forwarded unchanged to the kernel.

    Returns:
        ``(predicts, accept_index, accept_token_num)``: the int32 CUDA output
        buffers after the kernel has written them. They are still returned
        because the ``__main__`` block consumes them.
    """
    print(
        f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
    )
    device = "cuda"

    def _int32(rows):
        # Small helper: build an int32 tensor on the test device.
        return torch.tensor(rows, dtype=torch.int32, device=device)

    # Draft token ids proposed for each request.
    candidates = _int32([[0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12]])
    # Per-slot indices 0..11, one per entry of `predicts` (presumably the
    # flattened position of each draft slot -- confirm against the kernel).
    retrive_index = _int32([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])
    # Tree links between draft slots; -1 appears to mean "no next token / sibling".
    retrive_next_token = _int32([[1, 2, -1, 4, 5, -1], [4, 2, 3, -1, 5, -1]])
    retrive_next_sibling = _int32([[-1, 3, -1, -1, -1, -1], [-1, -1, -1, -1, 1, -1]])

    bs, num_draft_tokens = candidates.shape
    num_spec_step = 4
    vocab_size = 20

    # Logit 1 everywhere, with a single peak of 10 on the token the target model
    # "wants" at each draft slot.
    target_logits = torch.full(
        (bs, num_draft_tokens, vocab_size), 1, dtype=torch.float32, device=device
    )
    target_logits[0, 0, 3] = 10
    target_logits[0, 3, 4] = 10
    target_logits[0, 4, 5] = 10
    target_logits[1, 0, 11] = 10
    target_logits[1, 4, 12] = 10
    # Every slot that has no peak yet peaks on token 18 (vectorized; avoids a
    # host-synchronizing Python loop over CUDA rows).
    no_peak = target_logits.amax(dim=-1) < 10
    target_logits[no_peak, 18] = 10

    # At T=0.01 the 9-logit gap becomes 900, so off-peak probabilities underflow
    # to 0 in float32 and every row of target_probs is exactly one-hot.
    temperatures = torch.tensor([0.01, 0.01], dtype=torch.float32, device=device)
    target_probs = F.softmax(target_logits / temperatures[:, None, None], dim=-1)
    # "Target only" variant: draft probabilities are all zero.
    draft_probs = torch.zeros_like(target_probs)

    # Output buffers written in place by the kernel.
    predicts = torch.full(
        (bs * num_draft_tokens,), -1, dtype=torch.int32, device=device
    )  # mutable
    accept_index = torch.full(
        (bs, num_spec_step), -1, dtype=torch.int32, device=device
    )  # mutable
    accept_token_num = torch.zeros(bs, dtype=torch.int32, device=device)  # mutable

    # Seeded so the uniform draws (and hence the asserted outputs) are reproducible.
    generator = torch.Generator(device=device).manual_seed(0)
    coins = torch.rand(
        bs, num_draft_tokens, generator=generator, dtype=torch.float32, device=device
    )
    print(f"{candidates=}")
    print(f"{retrive_index=}")
    print(f"{retrive_next_token=}")
    print(f"{retrive_next_sibling=}")
    print(f"{coins=}")

    tree_speculative_sampling_target_only(
        predicts=predicts,
        accept_index=accept_index,
        accept_token_num=accept_token_num,
        candidates=candidates,
        retrive_index=retrive_index,
        retrive_next_token=retrive_next_token,
        retrive_next_sibling=retrive_next_sibling,
        uniform_samples=coins,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=threshold_single,
        threshold_acc=threshold_acc,
        deterministic=True,
    )

    print(f"{predicts=}")
    print(f"{accept_index=}")
    print(f"{accept_token_num=}")

    key = (threshold_single, threshold_acc)
    expected = _EXPECTED_RESULTS.get(key)
    if expected is not None:
        want_predicts, want_accept_index, want_accept_num = expected
        assert (
            predicts.tolist() == want_predicts
        ), f"predicts mismatch for {key}: {predicts.tolist()} != {want_predicts}"
        assert (
            accept_index.tolist() == want_accept_index
        ), f"accept_index mismatch for {key}: {accept_index.tolist()} != {want_accept_index}"
        assert (
            accept_token_num.tolist() == want_accept_num
        ), f"accept_token_num mismatch for {key}: {accept_token_num.tolist()} != {want_accept_num}"

    return predicts, accept_index, accept_token_num


def test_tree_speculative_sampling_target_only_zero_thresholds():
    """Also cover the (0, 0) threshold configuration when run under pytest."""
    test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)


if __name__ == "__main__":
    # Expected kernel outputs keyed by (threshold_single, threshold_acc); each
    # value is (predicts, accept_index, accept_token_num) as nested Python lists.
    # Dict order fixes the run order: (1, 1) first, then (0, 0).
    cases = {
        (1, 1): (
            [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18],
            [[0, 3, 4, 5], [6, 10, 11, -1]],
            [3, 2],
        ),
        (0, 0): (
            [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18],
            [[0, 1, 2, -1], [6, 10, 11, -1]],
            [2, 2],
        ),
    }
    output_names = ("predicts", "accept_index", "accept_token_num")
    for (single, acc), expected in cases.items():
        outputs = test_tree_speculative_sampling_target_only(
            threshold_single=single, threshold_acc=acc
        )
        # Compare in the same order as the outputs are returned.
        for name, tensor, want in zip(output_names, outputs, expected):
            got = tensor.tolist()
            assert (
                got == want
            ), f"{name} mismatch for thresholds {(single, acc)}: {got} != {want}"