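# speculative.py (3.19 KB)
# NOTE(review): this header was Git web-UI residue captured with the file
# ('"vscode:/vscode.git/clone" did not exist on
# "07dd6f8c0e267662f62c39cd8334c2b5d157ab39"', "Newer Older", and a gutter of
# line numbers 1-12). It is not Python, so it is preserved only as this comment.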
import torch


def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    uniform_samples_for_final_sampling: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    threshold_single: float = 1.0,
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
    """Invoke the ``sgl_kernel`` target-only tree speculative-sampling op.

    Thin wrapper: every argument is forwarded positionally, in declaration
    order, to
    ``torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default``.
    Nothing is returned. The first three tensors are marked ``# mutable``, so
    results are presumably written into them in place (TODO confirm against
    the registered op schema).

    Note: the ``retrive_*`` spelling is part of the public keyword interface
    and is kept as-is for caller compatibility.

    Args:
        predicts: Mutable output tensor.
        accept_index: Mutable output tensor.
        accept_token_num: Mutable output tensor.
        candidates: Candidate tokens; layout/dtype assumed by the kernel —
            confirm.
        retrive_index: Tree index tensor (semantics defined by the kernel).
        retrive_next_token: Tree link tensor (semantics defined by the kernel).
        retrive_next_sibling: Tree link tensor (semantics defined by the
            kernel).
        uniform_samples: Random samples; presumably uniform in [0, 1) given
            the name — verify.
        uniform_samples_for_final_sampling: Random samples used by the op's
            final sampling step, per the name — verify.
        target_probs: Target-model probabilities (assumed layout — confirm).
        draft_probs: Draft-model probabilities (assumed layout — confirm).
        threshold_single: Forwarded unchanged to the op; defaults to 1.0.
        threshold_acc: Forwarded unchanged to the op; defaults to 1.0.
        deterministic: Forwarded unchanged to the op; defaults to True.
    """
    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        uniform_samples_for_final_sampling,
        target_probs,
        draft_probs,
        threshold_single,
        threshold_acc,
        deterministic,
    )


def verify_tree_greedy(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    """Invoke the ``sgl_kernel`` greedy tree-verification op.

    Thin wrapper: every argument is forwarded positionally, in declaration
    order, to ``torch.ops.sgl_kernel.verify_tree_greedy.default``. Nothing is
    returned. The first three tensors are marked ``# mutable``, so results are
    presumably written into them in place (TODO confirm against the op schema).
    The name suggests greedy verification of candidate tokens against
    ``target_predict``; confirm against the kernel.

    Note: the ``retrive_*`` spelling is part of the public keyword interface
    and is kept as-is for caller compatibility.

    Args:
        predicts: Mutable output tensor.
        accept_index: Mutable output tensor.
        accept_token_num: Mutable output tensor.
        candidates: Candidate tokens (layout assumed by the kernel — confirm).
        retrive_index: Tree index tensor (semantics defined by the kernel).
        retrive_next_token: Tree link tensor (semantics defined by the kernel).
        retrive_next_sibling: Tree link tensor (semantics defined by the
            kernel).
        target_predict: Target-model predictions (assumed layout — confirm).
    """
    torch.ops.sgl_kernel.verify_tree_greedy.default(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        target_predict,
    )


def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int,
) -> None:
    """Invoke the ``sgl_kernel`` efficient tree-building op.

    Thin wrapper: every argument is forwarded positionally, in declaration
    order, to ``torch.ops.sgl_kernel.build_tree_kernel_efficient.default``.
    Nothing is returned. Presumably the op fills some of the passed tensors
    (e.g. ``tree_mask``, ``positions``, ``retrive_*``) in place — TODO confirm
    which arguments are outputs against the op schema.

    Note: the ``retrive_*`` spelling is part of the public keyword interface
    and is kept as-is for caller compatibility.

    Args:
        parent_list: Tensor forwarded to the op (semantics defined by kernel).
        selected_index: Tensor forwarded to the op.
        verified_seq_len: Tensor forwarded to the op.
        tree_mask: Tensor forwarded to the op.
        positions: Tensor forwarded to the op.
        retrive_index: Tree index tensor forwarded to the op.
        retrive_next_token: Tree link tensor forwarded to the op.
        retrive_next_sibling: Tree link tensor forwarded to the op.
        topk: Integer forwarded to the op; presumably the per-step branching
            factor — confirm.
        depth: Integer forwarded to the op; presumably the tree depth —
            confirm.
        draft_token_num: Integer forwarded to the op.
        tree_mask_mode: Integer forwarded to the op; the accepted values are
            defined by the kernel — confirm before adding new modes.
    """
    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
        tree_mask_mode,
    )


def reconstruct_indices_from_tree_mask(
    tree_mask: torch.Tensor,
    verified_seq_len: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    batch_size: int,
    draft_token_num: int,
) -> None:
    """Call ``torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask``.

    The six tensors and then the two integers are passed positionally, in
    the same order as this signature, to the op's ``.default`` overload.
    Nothing is returned. Which tensors the op writes to is defined by the
    kernel (presumably ``positions`` and the ``retrive_*`` tensors, per the
    name — TODO confirm).

    Args:
        tree_mask: Tensor forwarded to the op.
        verified_seq_len: Tensor forwarded to the op.
        positions: Tensor forwarded to the op.
        retrive_index: Tensor forwarded to the op.
        retrive_next_token: Tensor forwarded to the op.
        retrive_next_sibling: Tensor forwarded to the op.
        batch_size: Integer forwarded to the op.
        draft_token_num: Integer forwarded to the op.
    """
    # Resolve the registered overload first, exactly as the direct call would.
    kernel = torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default
    tensor_args = (
        tree_mask,
        verified_seq_len,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
    )
    kernel(*tensor_args, batch_size, draft_token_num)


def segment_packbits(
    x: torch.Tensor,
    input_indptr: torch.Tensor,
    output_indptr: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
) -> None:
    """Invoke the ``sgl_kernel`` segmented bit-packing op.

    Forwards ``x``, ``input_indptr``, ``output_indptr``, ``y`` and
    ``batch_size`` positionally to
    ``torch.ops.sgl_kernel.segment_packbits.default``. It also appends the raw
    handle of the current CUDA stream, presumably so the op runs on the
    caller's stream. Nothing is returned.

    The name suggests that segments of ``x`` (delimited by ``input_indptr``)
    are packed into ``y`` (delimited by ``output_indptr``) in place. This is
    an assumption — confirm against the kernel.

    Args:
        x: Input tensor forwarded to the op.
        input_indptr: Segment offsets into ``x`` (assumed — confirm).
        output_indptr: Segment offsets into ``y`` (assumed — confirm).
        y: Tensor forwarded to the op; presumably the output buffer.
        batch_size: Integer forwarded to the op.

    Note:
        ``torch.cuda.current_stream()`` is queried on every call, so this
        presumably requires an available CUDA device.
    """
    torch.ops.sgl_kernel.segment_packbits.default(
        x,
        input_indptr,
        output_indptr,
        y,
        batch_size,
        # Raw stream handle as an integer, taken from the current CUDA stream.
        torch.cuda.current_stream().cuda_stream,
    )