test_topk_topp_sampler.py 3.81 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import pytest
4
5
6
import torch
from torch import Generator

7
8
9
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                  is_flashinfer_available)
10

11
# Device type reported by the current platform; both tests set it as the
# torch default device and seed their generators on it.
DEVICE = current_platform.device_type

# Problem size shared by the tests below.
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024

# The FlashInfer comparison test only runs on CUDA platforms where the
# sampler module reports FlashInfer as available.
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available

# Import the FlashInfer renorm kernels only when FlashInfer is available, so
# this module can still be collected on machines without it.
if is_flashinfer_available:
    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
@pytest.fixture(autouse=True)
def reset_default_device():
    """
    Restore torch's default device after every test in this module.

    Tests here call ``torch.set_default_device``, which is global state and
    would otherwise leak into subsequent tests. This autouse fixture records
    the default device before each test and restores it during teardown.
    """
    original_device = torch.get_default_device()
    yield
    torch.set_default_device(original_device)


def test_topk_impl_equivalence():
    """
    Check that top-k filtering alone matches top-k combined with a no-op top-p.

    ``apply_top_k_top_p`` is called once with ``p=None`` (the top-k-only
    path) and once with ``p=1.0``. Because top-p with ``p=1.0`` is intended
    as a no-op, both calls must produce the same filtered logits.
    """
    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(33)

    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Random top-k values between 1 and 9.
    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)

    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
    k.masked_fill_(
        torch.randint(0,
                      2, (BATCH_SIZE, ),
                      generator=generator,
                      dtype=torch.bool), VOCAB_SIZE)

    # Top-k only implementation. Clone because the op may modify logits
    # in place.
    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

    # Top-p + top-k; a single-element p is broadcast across the batch.
    no_op_top_p = torch.tensor([1.0])
    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

    # allclose treats matching same-signed infinities (masked tokens) as equal.
    assert torch.allclose(result1, result2)


def test_flashinfer_sampler():
    """
    Verify that the FlashInfer top-k and top-p implementation produces the
    same probabilities as the Python implementation.

    NOTE: FlashInfer does not directly expose an interface for fused top-k
    and top-p probability renormalization. It does provide fused sampling,
    but sampling results cannot be compared because of randomness. Instead,
    we compare against the probabilities obtained by applying FlashInfer's
    top-k renormalization followed by its top-p renormalization.
    """

    if not FLASHINFER_ENABLED:
        pytest.skip(
            "FlashInfer not installed or not available on this platform.")

    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(42)

    # Generate random logits
    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Generate various top-k and top-p values
    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
    # torch.rand samples [0, 1), so p lies in [0.5, 1.0).
    p_values = torch.rand((BATCH_SIZE, ), generator=generator) * 0.5 + 0.5

    # Sometimes disable top-k (k=vocab_size)
    k_values.masked_fill_(
        torch.randint(0,
                      2, (BATCH_SIZE, ),
                      generator=generator,
                      dtype=torch.bool), VOCAB_SIZE)

    # Sometimes disable top-p (p=1.0)
    p_values.masked_fill_(
        torch.randint(0,
                      2, (BATCH_SIZE, ),
                      generator=generator,
                      dtype=torch.bool), 1.0)

    # Clone because apply_top_k_top_p may modify its input in place.
    python_logits = apply_top_k_top_p(
        logits=logits.clone(),
        k=k_values,
        p=p_values,
    )
    python_probs = torch.softmax(python_logits, dim=-1)

    # FlashInfer only exposes renorm interfaces for probs, so convert first.
    # torch.softmax returns a new tensor, so no defensive clone is needed.
    flashinfer_probs = torch.softmax(logits, dim=-1)
    flashinfer_probs = top_k_renorm_probs(
        probs=flashinfer_probs,
        top_k=k_values,
    )
    flashinfer_probs = top_p_renorm_probs(
        probs=flashinfer_probs,
        top_p=p_values,
    )

    # Compare the results
    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
        "FlashInfer and Python sampling implementations do not match!"