# test_frequency_penalty.py: unit tests for sglang's BatchedFrequencyPenalizer.
import collections
import typing
import unittest

import torch

from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
    BatchedFrequencyPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
    BaseBatchedPenalizerTest,
    MockSamplingParams,
    Step,
    StepType,
    Subject,
)


class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
    """Shared scenario for ``BatchedFrequencyPenalizer`` tests.

    Concrete subclasses set ``frequency_penalty``. This base class skips itself
    when collected directly. Each subject feeds the input tokens ``[0, 1, 2]``
    and then the output tokens ``[1, 2, 2]``. After each step it checks the
    penalizer's ``frequency_penalties`` and ``cumulated_frequency_penalties``
    tensors and the resulting logits.
    """

    Penalizer = BatchedFrequencyPenalizer

    # Penalty used by the "enabled" subject; provided by concrete subclasses.
    frequency_penalty: float

    def setUp(self):
        # The base class defines no ``frequency_penalty`` value, so skip it
        # rather than run the scenario with an undefined attribute.
        if self.__class__ == BaseBatchedFrequencyPenalizerTest:
            self.skipTest("Base class for frequency_penalty tests")

        super().setUp()

    def _create_subject(self, frequency_penalty: float) -> Subject:
        """Build a two-step (input, then output) subject for one penalty value.

        Args:
            frequency_penalty: Value passed to ``MockSamplingParams``. It is
                also used to compute the expected tensors and logits.

        Returns:
            A ``Subject`` whose steps carry the tensors and logits expected
            after each step.
        """
        input_token_ids = [0, 1, 2]
        output_token_ids = [1, 2, 2]

        # Occurrences of each token id among the output tokens. The cumulated
        # penalty for a token is ``frequency_penalty * count``. Tokens never
        # produced as output keep an exact 0.0.
        output_counts = collections.Counter(output_token_ids)
        penalty_row = [frequency_penalty] * self.vocab_size
        cumulated_row = [
            frequency_penalty * output_counts[i] if i in output_counts else 0.0
            for i in range(self.vocab_size)
        ]

        return Subject(
            sampling_params=MockSamplingParams(
                frequency_penalty=frequency_penalty,
            ),
            steps=[
                # The input step leaves every cumulated penalty at zero and the
                # logits at 1; only output tokens are expected to accumulate.
                Step(
                    type=StepType.INPUT,
                    token_ids=input_token_ids,
                    expected_tensors={
                        "frequency_penalties": self.tensor(
                            [penalty_row], dtype=torch.float32
                        ),
                        "cumulated_frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    expected_logits=self.tensor(
                        [[1] * self.vocab_size], dtype=torch.float32
                    ),
                ),
                # After the output step, each expected logit is 1 minus that
                # token's cumulated penalty.
                Step(
                    type=StepType.OUTPUT,
                    token_ids=output_token_ids,
                    expected_tensors={
                        "frequency_penalties": self.tensor(
                            [penalty_row], dtype=torch.float32
                        ),
                        "cumulated_frequency_penalties": self.tensor(
                            [cumulated_row], dtype=torch.float32
                        ),
                    },
                    expected_logits=self.tensor(
                        [[1.0 - c for c in cumulated_row]], dtype=torch.float32
                    ),
                ),
            ],
        )

    def create_test_subjects(self) -> typing.List[Subject]:
        """Create the enabled and disabled subjects for this test case.

        The subjects are stored on ``self.enabled`` and ``self.disabled``.
        They are also returned, so the method honors its
        ``typing.List[Subject]`` annotation; previously it returned ``None``.
        """
        self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
        # A zero penalty serves as the "disabled" configuration.
        self.disabled = self._create_subject(frequency_penalty=0.0)
        return [self.enabled, self.disabled]


class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
    """Run the shared frequency-penalty scenario with a positive penalty."""

    frequency_penalty = 0.12


class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
    """Run the shared frequency-penalty scenario with a negative penalty."""

    frequency_penalty = -0.12


if __name__ == "__main__":
    # Allow running this module directly as a script.
    unittest.main()