# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
4
5
6
import torch
from torch import Generator

7
from vllm.platforms import current_platform
8
9
10
11
from vllm.v1.sample.ops.topk_topp_sampler import (
    apply_top_k_top_p,
    is_flashinfer_available,
)
12

13
# Device type string reported by the active vLLM platform; the tests below
# install it as the torch default device.
DEVICE = current_platform.device_type

# The random logits used by the tests have shape (BATCH_SIZE, VOCAB_SIZE).
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024

# The FlashInfer comparison test is skipped unless this is truthy, i.e. the
# platform is CUDA and FlashInfer is available.
# NOTE(review): is_flashinfer_available is used as a plain truthy value here
# (it is never called) -- confirm it is a flag rather than a function.
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available

# The renorm kernels are only imported when FlashInfer is reported available.
if is_flashinfer_available:
    from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
21

22

Jee Jee Li's avatar
Jee Jee Li committed
23
24
25
@pytest.fixture(autouse=True)
def reset_default_device():
    """
    Put torch's default device back once each test has finished.

    The tests in this module explicitly set the default device, which can
    affect subsequent tests. This autouse fixture records the default device
    before the test body runs and reinstates it afterwards.
    """
    device_before_test = torch.get_default_device()
    yield
    # Teardown: undo whatever default device the test installed.
    torch.set_default_device(device_before_test)


34
def test_topk_impl_equivalence():
    """
    Check that ``apply_top_k_top_p`` returns matching results for the same
    top-k values whether ``p`` is omitted (``None``) or passed as a no-op 1.0.
    """
    torch.set_default_device(DEVICE)
    rng = Generator(device=DEVICE).manual_seed(33)

    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=rng)

    # Per-request k drawn from 1..9 (the upper bound of randint is exclusive).
    top_k = torch.randint(1, 10, (BATCH_SIZE,), generator=rng)

    # Roughly half of the requests get k == VOCAB_SIZE (top-k disabled).
    top_k_disabled = torch.randint(
        0, 2, (BATCH_SIZE,), generator=rng, dtype=torch.bool
    )
    top_k.masked_fill_(top_k_disabled, VOCAB_SIZE)

    # Path 1: top-k only, no top-p tensor at all.
    topk_only = apply_top_k_top_p(logits=logits.clone(), k=top_k, p=None)

    # Path 2: top-k together with a top-p of 1.0.
    no_op_top_p = torch.tensor([1.0])
    topk_and_topp = apply_top_k_top_p(
        logits=logits.clone(), k=top_k, p=no_op_top_p
    )

    assert torch.allclose(topk_only, topk_and_topp)
56
57
58


def test_flashinfer_sampler():
    """
    Compare FlashInfer's top-k / top-p probability renormalization with the
    Python implementation (``apply_top_k_top_p``).

    NOTE: FlashInfer did not directly expose an interface for fused top-k and
    top-p prob renorm (it did provide fused sampling but we cannot compare
    sampling results due to randomness), so the FlashInfer side is built by
    renormalizing the probabilities consecutively: first top-k, then top-p.
    """
    if not FLASHINFER_ENABLED:
        pytest.skip("FlashInfer not installed or not available on this platform.")

    torch.set_default_device(DEVICE)
    rng = Generator(device=DEVICE).manual_seed(42)

    # Random logits shared by both implementations.
    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=rng)

    # Per-request sampling parameters: k in 1..999 and p in [0.5, 1.0).
    top_k = torch.randint(1, 1000, (BATCH_SIZE,), generator=rng)
    uniform = torch.rand((BATCH_SIZE,), generator=rng)
    top_p = uniform * 0.5 + 0.5

    # Sometimes disable top-k (k=vocab_size).
    top_k_disabled = torch.randint(
        0, 2, (BATCH_SIZE,), generator=rng, dtype=torch.bool
    )
    top_k.masked_fill_(top_k_disabled, VOCAB_SIZE)

    # Sometimes disable top-p (p=1.0).
    top_p_disabled = torch.randint(
        0, 2, (BATCH_SIZE,), generator=rng, dtype=torch.bool
    )
    top_p.masked_fill_(top_p_disabled, 1.0)

    # Python side: filter the logits, then turn them into probabilities.
    filtered_logits = apply_top_k_top_p(logits=logits.clone(), k=top_k, p=top_p)
    python_probs = torch.softmax(filtered_logits, dim=-1)

    # FlashInfer side: its renorm interfaces take probabilities, so apply
    # softmax first and then renormalize by top-k followed by top-p.
    fi_probs = torch.softmax(logits.clone(), dim=-1)
    fi_probs = top_k_renorm_probs(probs=fi_probs, top_k=top_k)
    fi_probs = top_p_renorm_probs(probs=fi_probs, top_p=top_p)

    # The two results are compared with an absolute tolerance of 2e-2.
    assert torch.allclose(python_probs, fi_probs, atol=2e-2), (
        "FlashInfer and Python sampling implementations do not match!"
    )