test_topk_topp_sampler.py 3.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
import pytest
3
import torch
4
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
5
6
from torch import Generator

7
8
9
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                  is_flashinfer_available)
10
11
12
13
14
15

# Device on which the tests in this file create their generators and tensors.
DEVICE = "cuda"

# The random logits used by the tests have shape (BATCH_SIZE, VOCAB_SIZE).
# VOCAB_SIZE is also the value the tests write into k to mean
# "top-k disabled" for a request.
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024

16
17
# Gate for test_flashinfer_sampler: truthy only when the current platform
# reports CUDA and is_flashinfer_available is truthy.
# NOTE(review): is_flashinfer_available is used as a value here, not called;
# confirm it is a flag and not a function (a function would always be truthy).
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

def test_topk_impl_equivalance():
    '''
    Check that calling apply_top_k_top_p with top-k only (p=None) yields the
    same output as calling it with top-k plus a no-op top-p (p=1.0).
    '''
    with torch.device(DEVICE):
        rng = Generator(device=DEVICE).manual_seed(33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=rng)

        # Per-request top-k drawn from [1, 10), i.e. 1..9 inclusive.
        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=rng)

        # A random boolean mask (roughly half the batch) selects the requests
        # whose k is overwritten with VOCAB_SIZE, which disables top-k.
        top_k_disabled = torch.randint(0,
                                       2, (BATCH_SIZE, ),
                                       generator=rng,
                                       dtype=torch.bool)
        k.masked_fill_(top_k_disabled, VOCAB_SIZE)

        # Path 1: top-k only.
        top_k_only = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

        # Path 2: top-k combined with a top-p of 1.0, which filters nothing.
        p_one = torch.tensor([1.0])
        top_k_and_p = apply_top_k_top_p(logits=logits.clone(), k=k, p=p_one)

        assert torch.allclose(top_k_only, top_k_and_p)
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107


def test_flashinfer_sampler():
    '''
    Check that FlashInfer's top-k / top-p filtering agrees with the Python
    implementation (apply_top_k_top_p).

    NOTE: FlashInfer does not directly expose a fused top-k + top-p
    probability renormalization. It does offer fused sampling, but sampled
    tokens cannot be compared because of randomness. The comparison is
    therefore made on probabilities renormalized by FlashInfer's top-k
    kernel first and then by its top-p kernel.
    '''

    if not FLASHINFER_ENABLED:
        pytest.skip(
            "FlashInfer not installed or not available on this platform.")

    with torch.device(DEVICE):
        rng = Generator(device=DEVICE).manual_seed(42)

        # Random logits shared by both implementations.
        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=rng)

        # Per-request top-k in [1, 1000) and top-p in [0.5, 1.0].
        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=rng)
        unit_samples = torch.rand((BATCH_SIZE, ), generator=rng)
        p_values = unit_samples * 0.5 + 0.5

        # Randomly pick requests for which top-k is disabled (k=vocab_size).
        top_k_disabled = torch.randint(0,
                                       2, (BATCH_SIZE, ),
                                       generator=rng,
                                       dtype=torch.bool)
        k_values.masked_fill_(top_k_disabled, VOCAB_SIZE)

        # Randomly pick requests for which top-p is disabled (p=1.0).
        top_p_disabled = torch.randint(0,
                                       2, (BATCH_SIZE, ),
                                       generator=rng,
                                       dtype=torch.bool)
        p_values.masked_fill_(top_p_disabled, 1.0)

        # Reference: filter the logits in Python, then turn them into probs.
        filtered_logits = apply_top_k_top_p(
            logits=logits.clone(),
            k=k_values,
            p=p_values,
        )
        python_probs = filtered_logits.softmax(dim=-1)

        # FlashInfer's renorm kernels take probabilities, so apply softmax
        # first, then renormalize by top-k followed by top-p.
        unfiltered_probs = logits.clone().softmax(dim=-1)
        top_k_probs = top_k_renorm_probs(
            probs=unfiltered_probs,
            top_k=k_values,
        )
        flashinfer_probs = top_p_renorm_probs(
            probs=top_k_probs,
            top_p=p_values,
        )

        # Both routes must agree within the absolute tolerance.
        results_match = torch.allclose(python_probs,
                                       flashinfer_probs,
                                       atol=2e-2)
        assert results_match, \
            "FlashInfer and Python sampling implementations do not match!"