# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from flashinfer.sampling import top_k_renorm_probs, top_p_renorm_probs
from torch import Generator

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
                                                  is_flashinfer_available)

# All tensors in this module are created on this device.
DEVICE = "cuda"

# Problem size shared by every test: BATCH_SIZE rows of VOCAB_SIZE logits.
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024

# Gate for test_flashinfer_sampler: it requires a CUDA platform and a usable
# FlashInfer install, otherwise that test is skipped.
FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available

@pytest.fixture(autouse=True)
def reset_default_device():
    """
    Restore torch's global default device once each test has finished.

    The tests in this module call ``torch.set_default_device``. That setting
    is process-global, so without this autouse fixture it would leak into
    whichever test runs next.
    """
    saved_device = torch.get_default_device()
    yield
    # Teardown: undo any default-device change the test made.
    torch.set_default_device(saved_device)


def test_topk_impl_equivalance():
    """Check that top-k alone matches top-k combined with a no-op top-p.

    ``apply_top_k_top_p`` is run twice on identical logits:

    * with ``p=None`` (the top-k-only path);
    * with ``p=1.0`` (the combined top-k + top-p path, where top-p keeps
      everything).

    Both results must agree. Roughly half of the rows use
    ``k == VOCAB_SIZE``, so the "top-k disabled" case is exercised together
    with small real k values.
    """
    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(33)

    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Random top-k values between 1 and 9.
    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)

    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
    k.masked_fill_(
        torch.randint(0,
                      2, (BATCH_SIZE, ),
                      generator=generator,
                      dtype=torch.bool), VOCAB_SIZE)

    # Top-k only implementation. Clone because the op may work in place.
    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)

    # Top-p + top-k, with a top-p of 1.0 that should filter nothing. This is a
    # single-element tensor, presumably broadcast across the batch by
    # apply_top_k_top_p (TODO confirm).
    no_op_top_p = torch.tensor([1.0])
    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)

    assert torch.allclose(result1, result2), \
        "Top-k with a no-op top-p does not match top-k alone!"


def test_flashinfer_sampler():
    """Check FlashInfer's top-k/top-p renormalization against the Python
    ``apply_top_k_top_p`` implementation.

    NOTE: At the time of writing, FlashInfer did not expose a fused top-k +
    top-p probability-renormalization interface. It does provide fused
    *sampling*, but sampled tokens cannot be compared because of randomness.

    Instead, this test:

    * applies FlashInfer's top-k renorm and then its top-p renorm;
    * compares the resulting probabilities with the softmax of the
      Python-filtered logits.
    """
    if not FLASHINFER_ENABLED:
        pytest.skip(
            "FlashInfer not installed or not available on this platform.")

    torch.set_default_device(DEVICE)
    generator = Generator(device=DEVICE).manual_seed(42)

    # Generate random logits
    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)

    # Generate various top-k and top-p values
    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
    p_values = torch.rand(
        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range [0.5, 1.0)

    # Sometimes disable top-k (k=vocab_size)
    k_values.masked_fill_(
        torch.randint(0,
                      2, (BATCH_SIZE, ),
                      generator=generator,
                      dtype=torch.bool), VOCAB_SIZE)

    # Sometimes disable top-p (p=1.0)
    p_values.masked_fill_(
        torch.randint(0,
                      2, (BATCH_SIZE, ),
                      generator=generator,
                      dtype=torch.bool), 1.0)

    # Clone because the Python implementation may modify logits in place.
    python_logits = apply_top_k_top_p(
        logits=logits.clone(),
        k=k_values,
        p=p_values,
    )
    python_probs = torch.softmax(python_logits, dim=-1)

    # FlashInfer only exposes renorm interfaces for probs, so convert first.
    # torch.softmax returns a new tensor, so no defensive clone is needed.
    # At this batch/vocab size a clone would cost ~512 MB of GPU memory.
    flashinfer_probs = torch.softmax(logits, dim=-1)
    flashinfer_probs = top_k_renorm_probs(
        probs=flashinfer_probs,
        top_k=k_values,
    )
    flashinfer_probs = top_p_renorm_probs(
        probs=flashinfer_probs,
        top_p=p_values,
    )

    # Compare the results
    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
        "FlashInfer and Python sampling implementations do not match!"