Unverified Commit 95a28019 authored by Juwan Yoo's avatar Juwan Yoo Committed by GitHub
Browse files

test: negative value testing for frequency, presence penalizers (#995)

parent e040a245
...@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import ( ...@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import (
Subject, Subject,
) )
FREQUENCY_PENALTY = 0.12
class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest):
Penalizer = BatchedFrequencyPenalizer Penalizer = BatchedFrequencyPenalizer
frequency_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedFrequencyPenalizerTest:
self.skipTest("Base class for frequency_penalty tests")
super().setUp()
def _create_subject(self, frequency_penalty: float) -> Subject: def _create_subject(self, frequency_penalty: float) -> Subject:
return Subject( return Subject(
...@@ -72,9 +77,17 @@ class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest): ...@@ -72,9 +77,17 @@ class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest):
) )
def create_test_subjects(self) -> typing.List[Subject]: def create_test_subjects(self) -> typing.List[Subject]:
self.enabled = self._create_subject(frequency_penalty=FREQUENCY_PENALTY) self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
self.disabled = self._create_subject(frequency_penalty=0.0) self.disabled = self._create_subject(frequency_penalty=0.0)
class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = 0.12
class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = -0.12
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import ( ...@@ -14,11 +14,16 @@ from sglang.test.srt.sampling.penaltylib.utils import (
Subject, Subject,
) )
PRESENCE_PENALTY = 0.12
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
Penalizer = BatchedPresencePenalizer Penalizer = BatchedPresencePenalizer
presence_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedPresencePenalizerTest:
self.skipTest("Base class for presence_penalty tests")
super().setUp()
def _create_subject(self, presence_penalty: float) -> Subject: def _create_subject(self, presence_penalty: float) -> Subject:
return Subject( return Subject(
...@@ -72,9 +77,17 @@ class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest): ...@@ -72,9 +77,17 @@ class TestBatchedPresencePenalizer(BaseBatchedPenalizerTest):
) )
def create_test_subjects(self) -> typing.List[Subject]: def create_test_subjects(self) -> typing.List[Subject]:
self.enabled = self._create_subject(presence_penalty=PRESENCE_PENALTY) self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
self.disabled = self._create_subject(presence_penalty=0.0) self.disabled = self._create_subject(presence_penalty=0.0)
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
presence_penalty = 0.12
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
presence_penalty = -0.12
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment