# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import pytest
import torch
7
import torch_xla
8
9

from vllm.platforms import current_platform
10
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
11
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
12
13
14
15
16
17
18
19
20
21

# Skip the whole module at collection time when not running on a TPU.
if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)
# Placed after the skip guard, so this import only runs on TPU hosts.
import torch_xla.core.xla_model as xm

# The randomized tests draw logits of shape (BATCH_SIZE, VOCAB_SIZE).
BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
# Slack added to the retained probability mass before comparing it with p.
TOLERANCE = 1e-6


def test_topk_equivalence_to_native_impl():
    """Check that TPU top-k masking matches the native implementation.

    Random logits of shape (BATCH_SIZE, VOCAB_SIZE) are passed, with top-p
    disabled (``p=None``), through both ``apply_top_k_top_p_tpu`` and
    ``apply_top_k_top_p``; the two results must agree under
    ``torch.allclose``.
    """
    with torch.device(xm.xla_device()):
        # Fixed seed so the random logits and k values are reproducible.
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))

        # Random top-k values in [1, 10): randint's upper bound is exclusive.
        k = torch.randint(1, 10, (BATCH_SIZE,))

        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)

        # Each implementation gets its own copy of the logits.
        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)

        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
        assert torch.allclose(result_native, result_tpu)


def test_topp_result_sums_past_p():
    """Check that the tokens kept by top-p carry at least p probability mass.

    For random logits and random per-request p values, the probabilities of
    the entries that ``apply_top_k_top_p_tpu`` did not mask to infinity must
    sum to at least p (within TOLERANCE) for every row.
    """
    with torch.device(xm.xla_device()):
        # Fixed seed so the random logits and p values are reproducible.
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
        probs = logits.softmax(dim=-1)

        # Random top-p values between 0 and 1.
        p = torch.rand((BATCH_SIZE,))

        # Set p=1 for ~50% of requests in the batch (top-p disabled).
        p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)

        # k equal to the vocabulary size, so top-k is intended to be a no-op.
        no_op_k = torch.tensor([VOCAB_SIZE])
        logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)

        # Verify that the masked logit's probability sums to at least p.
        probs.masked_fill_(logits_masked.isinf(), 0)
        masked_prob_sum = probs.sum(dim=-1)

        torch_xla.sync()

    # Perform assertion on CPU.
    assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))


def test_topp_basic():
    """Check top-p on a small hand-written example with top-k set to 3.

    Each row holds the logs of three probabilities summing to 1. With
    p=0.79 the expected output masks only the smallest entry of each row
    (0.2 in row 0, 0.1 in row 1) to -inf.
    """
    with torch.device(xm.xla_device()):
        logits = torch.tensor(
            [
                [math.log(0.2), math.log(0.3), math.log(0.5)],
                [math.log(0.5), math.log(0.1), math.log(0.4)],
            ]
        )

        # k=3 equals the row length here, so only p=0.79 should filter.
        result = apply_top_k_top_p_tpu(
            logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
        )

        torch_xla.sync()

    # Expect the smallest elements to be dropped.
    expected_result = logits.clone().cpu()
    expected_result[0, 0] = float("-inf")
    expected_result[1, 1] = float("-inf")
    assert torch.allclose(expected_result, result.cpu())


def test_topp_select_all():
    """Check that p=1.0 with k equal to the row length leaves logits unchanged.

    The output of ``apply_top_k_top_p_tpu`` must match the input logits
    under ``torch.allclose``, i.e. no entry is masked.
    """
    with torch.device(xm.xla_device()):
        logits = torch.tensor(
            [
                [math.log(0.2), math.log(0.3), math.log(0.5)],
                [math.log(0.5), math.log(0.1), math.log(0.4)],
            ]
        )

        result = apply_top_k_top_p_tpu(
            logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
        )

        torch_xla.sync()

    # Compare on CPU, after the device work has been synced.
    assert torch.allclose(logits.cpu(), result.cpu())


def test_topp_with_ties():
    """Check that top-p keeps every token tied at the cut-off value.

    The single row has three entries equal to log(0.3) and one log(0.1).
    With p=0.2 the expected output keeps all three tied entries and masks
    only the log(0.1) entry to -inf.
    """
    with torch.device(xm.xla_device()):
        # Input has multiple math.log(0.3).
        logits = torch.tensor(
            [[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
        )

        result = apply_top_k_top_p_tpu(
            logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
        )

        torch_xla.sync()

    # All tie values are included in the top-p set. Tie breaking is left
    # to be done during final sampling (all tie tokens have equal
    # probability of being chosen).
    expected_result = logits.clone().cpu()
    expected_result[0, 3] = float("-inf")
    assert torch.allclose(expected_result, result.cpu())


def test_both_topk_topp():
    """Check top-k and top-p applied together with a per-row k.

    Row 0 uses k=1 and row 1 uses k=3, both with p=0.79. The expected
    output keeps only the largest entry of row 0, and masks only the
    smallest entry (0.1) of row 1.
    """
    with torch.device(xm.xla_device()):
        logits = torch.tensor(
            [
                [math.log(0.2), math.log(0.3), math.log(0.5)],
                [math.log(0.5), math.log(0.1), math.log(0.4)],
            ]
        )

        # Set k=1 for the first batch.
        result = apply_top_k_top_p_tpu(
            logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
        )

        torch_xla.sync()

    # Since for the first batch k=1, expect only the largest element gets
    # selected.
    expected_result = logits.clone().cpu()
    expected_result[0, 0] = float("-inf")
    expected_result[0, 1] = float("-inf")
    expected_result[1, 1] = float("-inf")
    assert torch.allclose(expected_result, result.cpu())