# bench_top_k_top_p_sampling.py
import itertools

import sgl_kernel
import torch
import triton
import triton.testing


def torch_top_k_top_p_joint_sampling_from_probs(
    normalized_prob, top_k, top_p, eps=1e-4
):
    """Reference PyTorch implementation of joint top-k top-p sampling."""
    batch_size, vocab_size = normalized_prob.shape
    samples = torch.empty(batch_size, dtype=torch.int64, device=normalized_prob.device)

    for i in range(batch_size):
        p_val = top_p[i].item()
        k_val = top_k[i].item()

        # top-p mask
        sorted_prob, indices = torch.sort(normalized_prob[i], descending=False)
        cdf = torch.cumsum(sorted_prob, dim=-1)
        mask_top_p = torch.zeros(
            vocab_size, dtype=torch.int32, device=normalized_prob.device
        )
        mask_top_p.scatter_add_(0, indices, (cdf > (1 - p_val) - eps).int())

        # top-k mask
        sorted_prob_desc, _ = torch.sort(normalized_prob[i], descending=True)
        pivot = sorted_prob_desc[k_val - 1]
        mask_top_k = (normalized_prob[i] >= pivot).int()

        # joint mask
        mask = torch.minimum(mask_top_p, mask_top_k).bool()

        # sample from masked probs
        masked_probs = normalized_prob[i] * mask
        masked_probs = masked_probs / masked_probs.sum()
        idx = torch.multinomial(masked_probs, 1)
        samples[i] = idx

    return samples


def calculate_diff(batch_size, vocab_size, p):
    """Compare Torch reference and SGLang kernel for correctness.

    Sampling is stochastic, so the two samplers need not return identical
    token ids; instead this verifies that every sample (from both the
    reference and the kernel) lies inside the joint top-k/top-p mask.

    Bug fix: the original computed both sample tensors and then silently
    discarded them — no comparison or output was ever performed.

    Raises:
        ValueError: if ``p`` is not one of the recognized sweep values.
        AssertionError: if either sampler picks a filtered-out token.
    """
    torch.manual_seed(42)
    # k is chosen so that (k, p) jointly filter a meaningful slice of vocab.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)

    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    torch_samples = torch_top_k_top_p_joint_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor
    )
    sglang_samples = sgl_kernel.top_k_top_p_sampling_from_probs(
        normalized_prob, top_k_tensor, top_p_tensor, filter_apply_order="joint"
    )

    # Recompute the joint mask per row and check both samplers respect it.
    eps = 1e-4
    for i in range(batch_size):
        row = normalized_prob[i]

        sorted_desc, _ = torch.sort(row, descending=True)
        mask_top_k = row >= sorted_desc[k - 1]

        sorted_asc, asc_idx = torch.sort(row, descending=False)
        cdf = torch.cumsum(sorted_asc, dim=-1)
        mask_top_p = torch.zeros_like(row, dtype=torch.bool)
        mask_top_p.scatter_(0, asc_idx, cdf > (1 - p) - eps)

        allowed = mask_top_k & mask_top_p
        assert allowed[torch_samples[i]], (
            f"torch sample {torch_samples[i].item()} outside joint mask (row {i})"
        )
        assert allowed[sglang_samples[i]], (
            f"sglang sample {sglang_samples[i].item()} outside joint mask (row {i})"
        )

    print(
        f"✅ batch_size={batch_size}, vocab_size={vocab_size}, p={p}: "
        "all samples within joint top-k/top-p mask"
    )


# Benchmark parameter sweep: every (batch_size, vocab_size, p) combination,
# with p varying fastest, then vocab_size, then batch_size.
batch_size_range = [16, 64, 128]
vocab_size_range = [111, 32000]
p_range = [0.1, 0.5]
configs = [
    (bs, vs, p)
    for bs in batch_size_range
    for vs in vocab_size_range
    for p in p_range
]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "vocab_size", "p"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "sglang"],
        line_names=["Torch Reference", "SGL Kernel"],
        styles=[("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="top-k-top-p-joint-sampling-performance",
        args={},
    )
)
def benchmark_sampling(batch_size, vocab_size, p, provider):
    """Time one (batch_size, vocab_size, p) config for the given provider.

    Returns (median, max, min) latency in microseconds; do_bench reports
    milliseconds, hence the * 1000 scaling to match the "us" ylabel.

    Raises:
        ValueError: if ``p`` or ``provider`` is not a recognized value.
    """
    torch.manual_seed(42)
    # Same (k, p) pairing as calculate_diff so both stress the joint filter.
    if p == 0.1:
        k = int(vocab_size * 0.5)
    elif p == 0.5:
        k = int(vocab_size * 0.1)
    else:
        raise ValueError("p not recognized")

    device = torch.device("cuda")
    pre_norm_prob = torch.rand(batch_size, vocab_size, device=device)
    normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    top_p_tensor = torch.full((batch_size,), p, device=device)
    top_k_tensor = torch.full((batch_size,), k, device=device)

    if provider == "torch":
        fn = lambda: torch_top_k_top_p_joint_sampling_from_probs(
            normalized_prob.clone(), top_k_tensor, top_p_tensor
        )
    elif provider == "sglang":
        fn = lambda: sgl_kernel.top_k_top_p_sampling_from_probs(
            normalized_prob.clone(),
            top_k_tensor,
            top_p_tensor,
            filter_apply_order="joint",
        )
    else:
        # Bug fix: an unknown provider previously fell through and crashed
        # with an opaque NameError on `fn`; fail with a clear message instead.
        raise ValueError(f"provider not recognized: {provider!r}")

    # quantiles -> (median, 20th percentile, 80th percentile) in milliseconds.
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    # Correctness check
    for cfg in configs:
        calculate_diff(*cfg)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_sampling.run(print_data=True)