# test_presence_penalty.py
1
import unittest
2
from typing import List
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

import torch

from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
    BatchedPresencePenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
    BaseBatchedPenalizerTest,
    MockSamplingParams,
    Step,
    StepType,
    Subject,
)


18
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
    """Shared scenario for ``BatchedPresencePenalizer`` tests.

    Concrete subclasses only set ``presence_penalty``. This class skips itself
    in ``setUp`` so that it is never executed as a test case of its own.
    """

    Penalizer = BatchedPresencePenalizer
    # Penalty value under test; assigned by each concrete subclass.
    presence_penalty: float

    def setUp(self):
        """Skip when running as the base class, otherwise defer to the parent."""
        # Exact-type check on purpose: subclasses must NOT be skipped, so an
        # ``isinstance`` test would be wrong here.
        if type(self) is BaseBatchedPresencePenalizerTest:
            self.skipTest("Base class for presence_penalty tests")

        super().setUp()

    def _create_subject(self, presence_penalty: float) -> Subject:
        """Build a ``Subject`` with one INPUT step followed by one OUTPUT step.

        Expectations encoded in the steps:

        * INPUT (tokens ``[0, 1, 2]``): cumulated penalties are all ``0.0`` and
          every expected logit is ``1``.
        * OUTPUT (tokens ``[1, 2, 2]``): ids 1 and 2 each carry a cumulated
          penalty of ``presence_penalty`` (token 2 occurs twice but is expected
          to be penalized once), and their expected logit is
          ``1.0 - presence_penalty``; all other ids are unchanged.

        Args:
            presence_penalty: Penalty value placed in the mock sampling params
                and used to compute the expected tensors.

        Returns:
            The ``Subject`` describing the sampling params and both steps.
        """
        vocab_size = self.vocab_size
        output_token_ids = [1, 2, 2]
        # Distinct ids that appear in the OUTPUT step; only these are expected
        # to be penalized.
        penalized_ids = set(output_token_ids)

        input_step = Step(
            type=StepType.INPUT,
            token_ids=[0, 1, 2],
            expected_tensors={
                "presence_penalties": self.tensor(
                    [[presence_penalty] * vocab_size], dtype=torch.float32
                ),
                "cumulated_presence_penalties": self.tensor(
                    [[0.0] * vocab_size], dtype=torch.float32
                ),
            },
            expected_logits=self.tensor([[1] * vocab_size], dtype=torch.float32),
        )

        output_step = Step(
            type=StepType.OUTPUT,
            token_ids=output_token_ids,
            expected_tensors={
                "presence_penalties": self.tensor(
                    [[presence_penalty] * vocab_size], dtype=torch.float32
                ),
                "cumulated_presence_penalties": self.tensor(
                    [
                        [
                            presence_penalty if i in penalized_ids else 0.0
                            for i in range(vocab_size)
                        ],
                    ],
                    dtype=torch.float32,
                ),
            },
            expected_logits=self.tensor(
                [
                    [
                        1.0 - presence_penalty if i in penalized_ids else 1.0
                        for i in range(vocab_size)
                    ],
                ],
                dtype=torch.float32,
            ),
        )

        return Subject(
            sampling_params=MockSamplingParams(
                presence_penalty=presence_penalty,
            ),
            steps=[input_step, output_step],
        )

    def create_test_subjects(self) -> List[Subject]:
        """Store the enabled and disabled subjects on the test instance.

        ``self.enabled`` uses the class-level ``presence_penalty``;
        ``self.disabled`` uses a penalty of ``0.0``.
        """
        # NOTE(review): annotated as returning List[Subject], but nothing is
        # returned -- the subjects are stored on ``self`` instead. The
        # annotation is left as-is in case it mirrors the base-class
        # signature; confirm against BaseBatchedPenalizerTest.
        self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
        self.disabled = self._create_subject(presence_penalty=0.0)


84
85
86
87
88
89
90
91
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
    """Run the shared presence-penalty scenario with a positive penalty."""

    presence_penalty: float = 0.12


class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
    """Run the shared presence-penalty scenario with a negative penalty."""

    presence_penalty: float = -0.12


92
93
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()