# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from torch import Generator

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p

# Device type reported by the current platform; the tests below make it the
# torch default device and seed their generators on it.
DEVICE = current_platform.device_type

# The random logits used by the tests have shape (BATCH_SIZE, VOCAB_SIZE).
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024


@pytest.fixture(autouse=True)
def reset_default_device():
    """
    Restore torch's default device after every test in this module.

    The tests explicitly set the default device, which can affect subsequent
    tests. Adding this fixture helps avoid this problem: the device that was
    active before the test is recorded and put back once the test finishes.
    """
    original_device = torch.get_default_device()
    # Everything after the yield runs as fixture teardown.
    yield
    torch.set_default_device(original_device)


def test_topk_impl_equivalence():
    """
    Check that the top-k-only call and the top-k + top-p call agree.

    ``apply_top_k_top_p`` is invoked once with ``p=None`` and once with a
    no-op ``p`` of 1.0 on identical logits and ``k``; both results must match.
    """
    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(33)

    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Random top-k values between 1 and 9.
    k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)

    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
    k.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
        VOCAB_SIZE,
    )

    # Top-k only implementation. The logits are cloned for each call so both
    # calls start from the same values.
    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

    # Top-p + top-k, with p=1.0 so that top-p should not filter anything.
    no_op_top_p = torch.tensor([1.0])
    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

    assert torch.allclose(result1, result2)


def test_flashinfer_sampler():
    """
    This test verifies that the FlashInfer top-k and top-p sampling
    implementation produces the same results as the Python implementation.

    NOTE: FlashInfer did not directly expose an interface for fused top-k and
    top-p prob renorm (it did provide fused sampling but we cannot compare
    sampling results due to randomness), so we compare against the
    probabilities renormalized first by top-k and then by top-p with the
    FlashInfer implementation.

    The test is skipped when the platform is not CUDA or FlashInfer cannot be
    imported.
    """
    skip_reason = "FlashInfer not installed or not available on this platform."

    # Check the platform first so that the FlashInfer import is never
    # attempted on platforms where it is not expected to work.
    if not current_platform.is_cuda():
        pytest.skip(skip_reason)

    try:
        from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
    except ImportError:
        pytest.skip(skip_reason)

    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(42)

    # Generate random logits
    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Generate various top-k and top-p values
    k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
    p_values = (
        torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
    )  # range in [0.5, 1.0)

    # Sometimes disable top-k (k=vocab_size)
    k_values.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
        VOCAB_SIZE,
    )

    # Sometimes disable top-p (p=1.0)
    p_values.masked_fill_(
        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
    )

    # Reference path: filter the logits, then turn them into probabilities.
    python_logits = apply_top_k_top_p(
        logits=logits.clone(),
        k=k_values,
        p=p_values,
    )
    python_probs = torch.softmax(python_logits, dim=-1)

    # FlashInfer only exposed renorm interfaces for probs so convert first
    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
    flashinfer_probs = top_k_renorm_probs(
        probs=flashinfer_probs,
        top_k=k_values,
    )
    flashinfer_probs = top_p_renorm_probs(
        probs=flashinfer_probs,
        top_p=p_values,
    )

    # Compare the results
    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
        "FlashInfer and Python sampling implementations do not match!"
    )