import torch
from sgl_kernel.utils import get_cuda_stream


def tree_speculative_sampling_target_only(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    uniform_samples: torch.Tensor,
    uniform_samples_for_final_sampling: torch.Tensor,
    target_probs: torch.Tensor,
    draft_probs: torch.Tensor,
    threshold_single: float = 1.0,
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
    """Dispatch to the ``sgl_kernel`` tree speculative-sampling custom op.

    This is a thin wrapper: every argument is forwarded positionally, in
    declaration order, to
    ``torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default``,
    followed by the value returned by ``get_cuda_stream()``. Nothing is
    returned; results are delivered through the tensors marked mutable.

    Args:
        predicts: Mutable; written in place by the op.
        accept_index: Mutable; written in place by the op.
        accept_token_num: Mutable; written in place by the op.
        candidates: Forwarded unchanged (presumably the candidate draft
            token ids -- TODO confirm against the kernel).
        retrive_index: Forwarded unchanged. The ``retrive`` spelling is part
            of the public keyword interface and must be kept.
        retrive_next_token: Forwarded unchanged (presumably first-child
            links of the draft tree -- TODO confirm).
        retrive_next_sibling: Forwarded unchanged (presumably sibling links
            of the draft tree -- TODO confirm).
        uniform_samples: Forwarded unchanged (presumably uniform random
            numbers consumed by the accept/reject test -- TODO confirm).
        uniform_samples_for_final_sampling: Forwarded unchanged (presumably
            uniform random numbers for the final token draw -- TODO confirm).
        target_probs: Forwarded unchanged (presumably target-model
            probabilities -- TODO confirm).
        draft_probs: Forwarded unchanged (presumably draft-model
            probabilities -- TODO confirm).
        threshold_single: Forwarded unchanged; defaults to ``1.0``.
        threshold_acc: Forwarded unchanged; defaults to ``1.0``.
        deterministic: Forwarded unchanged; defaults to ``True``.

    NOTE(review): tensor shapes, dtypes and device requirements are enforced
    by the kernel, not by this wrapper -- check the op schema before calling.
    """
    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        uniform_samples,
        uniform_samples_for_final_sampling,
        target_probs,
        draft_probs,
        threshold_single,
        threshold_acc,
        deterministic,
        # The op takes the stream handle as its last positional argument.
        get_cuda_stream(),
    )


def verify_tree_greedy(
    predicts: torch.Tensor,  # mutable
    accept_index: torch.Tensor,  # mutable
    accept_token_num: torch.Tensor,  # mutable
    candidates: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
    """Dispatch to the ``sgl_kernel`` greedy tree-verification custom op.

    This is a thin wrapper: every argument is forwarded positionally, in
    declaration order, to ``torch.ops.sgl_kernel.verify_tree_greedy.default``,
    followed by the value returned by ``get_cuda_stream()``. Nothing is
    returned; results are delivered through the tensors marked mutable.

    Args:
        predicts: Mutable; written in place by the op.
        accept_index: Mutable; written in place by the op.
        accept_token_num: Mutable; written in place by the op.
        candidates: Forwarded unchanged (presumably the candidate draft
            token ids -- TODO confirm against the kernel).
        retrive_index: Forwarded unchanged. The ``retrive`` spelling is part
            of the public keyword interface and must be kept.
        retrive_next_token: Forwarded unchanged (presumably first-child
            links of the draft tree -- TODO confirm).
        retrive_next_sibling: Forwarded unchanged (presumably sibling links
            of the draft tree -- TODO confirm).
        target_predict: Forwarded unchanged (presumably the target model's
            predicted token ids -- TODO confirm).

    NOTE(review): tensor shapes, dtypes and device requirements are enforced
    by the kernel, not by this wrapper -- check the op schema before calling.
    """
    torch.ops.sgl_kernel.verify_tree_greedy.default(
        predicts,
        accept_index,
        accept_token_num,
        candidates,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        target_predict,
        # The op takes the stream handle as its last positional argument.
        get_cuda_stream(),
    )


def build_tree_kernel_efficient(
    parent_list: torch.Tensor,
    selected_index: torch.Tensor,
    verified_seq_len: torch.Tensor,
    tree_mask: torch.Tensor,
    positions: torch.Tensor,
    retrive_index: torch.Tensor,
    retrive_next_token: torch.Tensor,
    retrive_next_sibling: torch.Tensor,
    topk: int,
    depth: int,
    draft_token_num: int,
    tree_mask_mode: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` draft-tree construction custom op.

    This is a thin wrapper: every argument is forwarded positionally, in
    declaration order, to
    ``torch.ops.sgl_kernel.build_tree_kernel_efficient.default``. No stream
    handle is appended to the call. Nothing is returned, so any results must
    be delivered through the tensor arguments.

    Args:
        parent_list: Forwarded unchanged.
        selected_index: Forwarded unchanged.
        verified_seq_len: Forwarded unchanged.
        tree_mask: Forwarded unchanged (presumably an output buffer filled
            by the kernel -- TODO confirm).
        positions: Forwarded unchanged (presumably an output buffer filled
            by the kernel -- TODO confirm).
        retrive_index: Forwarded unchanged (presumably an output buffer --
            TODO confirm). The ``retrive`` spelling is part of the public
            keyword interface and must be kept.
        retrive_next_token: Forwarded unchanged (presumably an output
            buffer -- TODO confirm).
        retrive_next_sibling: Forwarded unchanged (presumably an output
            buffer -- TODO confirm).
        topk: Forwarded unchanged.
        depth: Forwarded unchanged.
        draft_token_num: Forwarded unchanged.
        tree_mask_mode: Forwarded unchanged; the meaning of each integer
            value is defined by the kernel, not visible here.

    NOTE(review): tensor shapes, dtypes and device requirements are enforced
    by the kernel, not by this wrapper -- check the op schema before calling.
    """
    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
        parent_list,
        selected_index,
        verified_seq_len,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        depth,
        draft_token_num,
        tree_mask_mode,
    )


def segment_packbits(
    x: torch.Tensor,
    input_indptr: torch.Tensor,
    output_indptr: torch.Tensor,
    y: torch.Tensor,
    batch_size: int,
) -> None:
    """Dispatch to the ``sgl_kernel`` segmented bit-packing custom op.

    This is a thin wrapper: every argument is forwarded positionally, in
    declaration order, to ``torch.ops.sgl_kernel.segment_packbits.default``,
    followed by the raw handle of the current CUDA stream. Nothing is
    returned, so any result must be delivered through a tensor argument.

    Args:
        x: Forwarded unchanged (presumably the unpacked input values --
            TODO confirm against the kernel).
        input_indptr: Forwarded unchanged (presumably per-segment offsets
            into ``x`` -- TODO confirm).
        output_indptr: Forwarded unchanged (presumably per-segment offsets
            into ``y`` -- TODO confirm).
        y: Forwarded unchanged (presumably the output buffer that receives
            the packed bits -- TODO confirm).
        batch_size: Forwarded unchanged (presumably the number of
            segments -- TODO confirm).

    NOTE(review): tensor shapes, dtypes and device requirements are enforced
    by the kernel, not by this wrapper -- check the op schema before calling.
    """
    torch.ops.sgl_kernel.segment_packbits.default(
        x,
        input_indptr,
        output_indptr,
        y,
        batch_size,
        # NOTE(review): other wrappers in this module obtain the stream via
        # get_cuda_stream(); presumably equivalent -- confirm before unifying.
        torch.cuda.current_stream().cuda_stream,
    )