test_topk_topp_sampler.py 3.92 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import pytest
4
5
6
import torch
from torch import Generator

7
from vllm.platforms import current_platform
8
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
9

10
# Device type reported by the current vLLM platform. The tests below use it
# both as the torch default device and as the device of their random
# generators.
DEVICE = current_platform.device_type
11
12
13
14
15

# Shape of the random logits tensor built in each test:
# (BATCH_SIZE, VOCAB_SIZE). VOCAB_SIZE is also the fill value the tests use
# for k when top-k is meant to be disabled for a request.
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024


Jee Jee Li's avatar
Jee Jee Li committed
16
17
18
@pytest.fixture(autouse=True)
def reset_default_device():
    """
    Restore torch's default device once each test has finished.

    The tests in this file explicitly set the default device, which can
    affect subsequent tests. This autouse fixture avoids that problem by
    recording the default device before the test runs and putting it back
    afterwards.
    """
    saved_device = torch.get_default_device()
    yield
    torch.set_default_device(saved_device)


27
def test_topk_impl_equivalence():
    """Check that top-k alone matches top-k combined with a no-op top-p.

    ``apply_top_k_top_p`` is called twice on identical logits and ``k``
    values: once with ``p=None`` and once with ``p`` set to 1.0. The two
    outputs are compared with ``torch.allclose``.
    """
    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(33)

    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Random top-k values between 1 and 9.
    k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)

    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
    k.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
        VOCAB_SIZE,
    )

    # Top-k only implementation
    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

    # Top-p + top-k, where p=1.0 is intended to leave top-p as a no-op.
    no_op_top_p = torch.tensor([1.0])
    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

    assert torch.allclose(result1, result2), (
        "Top-k only and top-k + no-op top-p results do not match!"
    )
49
50


51
52
53
54
55
@pytest.mark.skip(
    reason="FlashInfer top-k/top-p renorm comparison fails; "
    "needs investigation of tolerance threshold or "
    "interface differences between Python and FlashInfer implementations"
)
def test_flashinfer_sampler():
    """
    This test verifies that the FlashInfer top-k and top-p sampling
    implementation produces the same results as the Python implementation.

    NOTE: FlashInfer did not directly expose an interface for fused top-k and
    top-p prob renorm (it did provide fused sampling but we cannot compare
    sampling results due to randomness), so we compare against probabilities
    renormalized consecutively by FlashInfer's top-k and then top-p
    functions.
    """
    # Import lazily so the module can still be collected when FlashInfer is
    # not installed; the test skips itself below in that case.
    try:
        from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs

        is_flashinfer_available = True
    except ImportError:
        is_flashinfer_available = False

    flashinfer_enabled = current_platform.is_cuda() and is_flashinfer_available

    if not flashinfer_enabled:
        pytest.skip("FlashInfer not installed or not available on this platform.")

    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(42)

    # Generate random logits
    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Generate various top-k and top-p values
    k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
    p_values = (
        torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
    )  # range in [0.5, 1.0]

    # Sometimes disable top-k (k=vocab_size)
    k_values.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
        VOCAB_SIZE,
    )

    # Sometimes disable top-p (p=1.0)
    p_values.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
    )

    # Reference path: filter the logits, then convert to probabilities.
    python_logits = apply_top_k_top_p(
        logits=logits.clone(),
        k=k_values,
        p=p_values,
    )
    python_probs = torch.softmax(python_logits, dim=-1)

    # FlashInfer only exposed renorm interfaces for probs so convert first
    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
    flashinfer_probs = top_k_renorm_probs(
        probs=flashinfer_probs,
        top_k=k_values,
    )
    flashinfer_probs = top_p_renorm_probs(
        probs=flashinfer_probs,
        top_p=p_values,
    )

    # Compare the results
    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
        "FlashInfer and Python sampling implementations do not match!"
    )