Unverified Commit 95a28019 authored by Juwan Yoo's avatar Juwan Yoo Committed by GitHub
Browse files

test: negative value testing for frequency, presence penalizers (#995)

parent e040a245
......@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import (
Subject,
)
FREQUENCY_PENALTY = 0.12
class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest):
class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedFrequencyPenalizer
frequency_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedFrequencyPenalizerTest:
self.skipTest("Base class for frequency_penalty tests")
super().setUp()
def _create_subject(self, frequency_penalty: float) -> Subject:
return Subject(
......@@ -72,9 +77,17 @@ class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest):
)
def create_test_subjects(self) -> typing.List[Subject]:
self.enabled = self._create_subject(frequency_penalty=FREQUENCY_PENALTY)
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
self.disabled = self._create_subject(frequency_penalty=0.0)
class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = 0.12
class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = -0.12
if __name__ == "__main__":
unittest.main()
......@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import (
Subject,
)
PRESENCE_PENALTY = 0.12
class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedPresencePenalizer
presence_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedPresencePenalizerTest:
self.skipTest("Base class for presence_penalty tests")
super().setUp()
def _create_subject(self, presence_penalty: float) -> Subject:
return Subject(
......@@ -72,9 +77,17 @@ class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
)
def create_test_subjects(self) -> typing.List[Subject]:
self.enabled = self._create_subject(presence_penalty=PRESENCE_PENALTY)
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
self.disabled = self._create_subject(presence_penalty=0.0)
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
presence_penalty = 0.12
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
presence_penalty = -0.12
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment