# test_frequency_penalty.py: unit tests for BatchedFrequencyPenalizer.
1
import unittest
2
from typing import List
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

import torch

from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
    BatchedFrequencyPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
    BaseBatchedPenalizerTest,
    MockSamplingParams,
    Step,
    StepType,
    Subject,
)


class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
    """Shared frequency-penalty scenario for ``BatchedFrequencyPenalizer``.

    Concrete subclasses set ``frequency_penalty``. ``create_test_subjects``
    builds an "enabled" subject with that value and a "disabled" subject with
    a penalty of 0.0. This base class skips itself in ``setUp`` so test
    discovery does not run it without a penalty value.
    """

    Penalizer = BatchedFrequencyPenalizer

    # Penalty used for the "enabled" subject; each concrete subclass sets it.
    frequency_penalty: float

    def setUp(self):
        # The abstract base has no frequency_penalty value, so never run it.
        if self.__class__ == BaseBatchedFrequencyPenalizerTest:
            self.skipTest("Base class for frequency_penalty tests")

        super().setUp()

    def _create_subject(self, frequency_penalty: float) -> Subject:
        """Return a subject with one INPUT step followed by one OUTPUT step.

        Expectations encoded below:

        * After the INPUT step, the cumulated penalty is 0.0 everywhere and
          every logit is still 1.
        * After the OUTPUT step, each vocab id has accumulated
          ``frequency_penalty * <number of times it was output>``. Its logit
          is reduced by that amount.

        Args:
            frequency_penalty: Value placed in the mock sampling params and
                expected in the ``frequency_penalties`` tensor.

        Returns:
            The ``Subject`` describing the steps and their expected tensors.
        """

        def full(value: float) -> torch.Tensor:
            # Fresh (1, vocab_size) float32 tensor filled with `value`.
            return self.tensor([[value] * self.vocab_size], dtype=torch.float32)

        input_ids = [0, 1, 2]
        # Output ids of one request, one id per decoding step (three steps).
        output_ids = [1, 2, 2]

        # Expected cumulated penalty per vocab id, derived from how often each
        # id occurs in output_ids. Ids that are never output stay at exactly
        # 0.0, which avoids -0.0 when the penalty is negative.
        cumulated = []
        for token_id in range(self.vocab_size):
            count = output_ids.count(token_id)
            cumulated.append(frequency_penalty * count if count else 0.0)

        return Subject(
            sampling_params=MockSamplingParams(
                frequency_penalty=frequency_penalty,
            ),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=input_ids,
                    expected_tensors={
                        "frequency_penalties": full(frequency_penalty),
                        "cumulated_frequency_penalties": full(0.0),
                    },
                    expected_logits=full(1.0),
                ),
                Step(
                    type=StepType.OUTPUT,
                    token_ids=output_ids,
                    expected_tensors={
                        "frequency_penalties": full(frequency_penalty),
                        "cumulated_frequency_penalties": self.tensor(
                            [cumulated], dtype=torch.float32
                        ),
                    },
                    # Each logit starts at 1.0 and is reduced by its penalty.
                    expected_logits=self.tensor(
                        [[1.0 - penalty for penalty in cumulated]],
                        dtype=torch.float32,
                    ),
                ),
            ],
        )

    def create_test_subjects(self) -> List[Subject]:
        """Build the enabled and disabled subjects for the shared test cases.

        The subjects are stored on ``self.enabled`` and ``self.disabled``
        (presumably read by ``BaseBatchedPenalizerTest``; confirm there). They
        are also returned, so the result matches the ``List[Subject]``
        annotation.
        """
        self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
        self.disabled = self._create_subject(frequency_penalty=0.0)
        return [self.enabled, self.disabled]


class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
    """Run the shared frequency-penalty scenario with a positive penalty."""

    frequency_penalty = 0.12


class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
    """Run the shared frequency-penalty scenario with a negative penalty."""

    frequency_penalty = -0.12


# Allow running this module directly: `python test_frequency_penalty.py`.
if __name__ == "__main__":
    unittest.main()