# bench_top_k_top_p_sampling.py
# Benchmark of joint top-k/top-p sampling: PyTorch reference vs. sgl_kernel.
import itertools

import sgl_kernel
import torch
import triton
import triton.testing


def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Draw one token per row under a joint top-k / top-p filter (PyTorch reference).

    Each row is processed independently: a top-p keep-mask and a top-k
    keep-mask are built, intersected, and one index is sampled from the
    surviving probabilities after renormalizing them.

    Args:
        normalized_prob: 2-D tensor of probabilities, one row per batch entry.
            Rows are assumed to already sum to 1.
        top_k: Per-row k values; ``top_k[i].item()`` is read for row ``i``.
        top_p: Per-row p values; ``top_p[i].item()`` is read for row ``i``.
        eps: Slack subtracted from the top-p cumulative-mass threshold.

    Returns:
        An int64 tensor with one sampled index per row, allocated on the
        same device as ``normalized_prob``.
    """
    num_rows, num_tokens = normalized_prob.shape
    dev = normalized_prob.device
    samples = torch.empty(num_rows, dtype=torch.int64, device=dev)

    for row, row_probs in enumerate(normalized_prob):
        nucleus_mass = top_p[row].item()
        keep_count = top_k[row].item()

        # Top-p: walk the probabilities from smallest to largest and keep
        # every token whose running mass exceeds (1 - p) - eps; the flags are
        # then scattered back from sorted order to vocabulary order.
        ascending, order = torch.sort(row_probs, descending=False)
        running_mass = torch.cumsum(ascending, dim=-1)
        in_nucleus = (running_mass > (1 - nucleus_mass) - eps).int()
        keep_top_p = torch.zeros(num_tokens, dtype=torch.int32, device=dev)
        keep_top_p.scatter_add_(0, order, in_nucleus)

        # Top-k: keep everything at least as large as the k-th largest
        # probability (ties with that pivot are all kept).
        descending, _ = torch.sort(row_probs, descending=True)
        kth_largest = descending[keep_count - 1]
        keep_top_k = (row_probs >= kth_largest).int()

        # A token survives only if both filters keep it.
        keep = torch.minimum(keep_top_p, keep_top_k).bool()

        # Zero out filtered tokens, renormalize, and sample a single index.
        filtered = row_probs * keep
        filtered = filtered / filtered.sum()
        samples[row] = torch.multinomial(filtered, 1)

    return samples


def calculate_diff(batch_size, vocab_size, p):
    """Run the Torch reference and the SGLang kernel on the same random inputs.

    The RNG is seeded with 42, a ``(batch_size, vocab_size)`` matrix of
    uniform random values is created on CUDA and row-normalized, and both
    samplers are invoked with constant per-row ``top_p = p`` and a ``top_k``
    derived from ``p`` (0.1 -> half the vocabulary, 0.5 -> a tenth of it).

    NOTE(review): both sample tensors are computed but neither compared nor
    returned, so this currently only verifies that the two calls execute.

    Raises:
        ValueError: If ``p`` is neither 0.1 nor 0.5.
    """
    torch.manual_seed(42)
    if p not in (0.1, 0.5):
        raise ValueError("p not recognized")
    # Pair the small p with a large k and vice versa.
    k = int(vocab_size * (0.5 if p == 0.1 else 0.1))

    cuda = torch.device("cuda")
    raw_scores = torch.rand(batch_size, vocab_size, device=cuda)
    normalized_prob = raw_scores / raw_scores.sum(dim=-1, keepdim=True)

    per_row_p = torch.full((batch_size,), p, device=cuda)
    per_row_k = torch.full((batch_size,), k, device=cuda)

    reference_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, per_row_k, per_row_p
    )
    kernel_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, per_row_k, per_row_p, filter_apply_order="joint"
    )


# Benchmark grid: every (batch_size, vocab_size, p) combination is exercised.
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = [*itertools.product(batch_size_range, vocab_size_range, p_range)]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Time one joint top-k/top-p sampling call for the given provider.

    Inputs are built the same way for both providers: the RNG is seeded with
    42, a ``(batch_size, vocab_size)`` uniform random matrix is created on
    CUDA and row-normalized, and ``top_p``/``top_k`` are constant per row
    (``p == 0.1`` pairs with k = 50% of the vocabulary, ``p == 0.5`` with 10%).

    Args:
        batch_size: Number of rows to sample.
        vocab_size: Number of columns (candidate tokens) per row.
        p: Top-p value; only 0.1 and 0.5 are accepted.
        provider: ``"torch"`` for the PyTorch reference or ``"sglang"`` for
            the sgl_kernel implementation.

    Returns:
        A ``(median, low, high)`` tuple of timings in microseconds, taken at
        the 0.5, 0.2 and 0.8 quantiles, in the ``(y, y_min, y_max)`` order
        that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: If ``p`` or ``provider`` is not one of the supported values.
    """
    torch.manual_seed(42)
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    # Each timed call works on a fresh clone of the probabilities.
    if provider == "torch":
        # NOTE(review): the torch reference in this file reads top_k/top_p
        # with .item() per row, which synchronizes with the host; confirm
        # that this is permitted inside the CUDA-graph capture performed by
        # do_bench_cudagraph.
        def fn():
            return torch_top_k_top_p_joint_sampling_from_probs(
                normalized_prob.clone(), top_k_tensor, top_p_tensor
            )

    elif provider == "sglang":

        def fn():
            return sgl_kernel.top_k_top_p_sampling_from_probs(
                normalized_prob.clone(),
                top_k_tensor,
                top_p_tensor,
                filter_apply_order="joint",
            )

    else:
        # Fail loudly instead of hitting an unbound `fn` further down.
        raise ValueError(f"provider not recognized: {provider!r}")

    # Quantiles come back in the order requested: median, 20th, 80th percentile.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )

    # Convert ms -> us. The reported metric is a time, so the low quantile is
    # the lower bound and the high quantile the upper bound (the swapped order
    # seen in throughput benchmarks would invert the error band here).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    # Run both implementations once on every configuration before timing.
    for config in configs:
        calculate_diff(*config)

    separator = "=" * 60
    print("\n" + separator)
    print("Starting performance benchmark...")
    # Print the timing table for every configuration and provider.
    benchmark_sampling.run(print_data=True)