Unverified commit 57a404fd, authored by Qiaolin Yu and committed by GitHub

Remove outdated test utils and fix links to the sampling params doc (#3999)

parent 2796fbb5
@@ -94,7 +94,7 @@ def gen(
     regex: Optional[str] = None,
     json_schema: Optional[str] = None,
 ):
-    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+    """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
     if choices:
         return SglSelect(
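
For context, these are the arguments a frontend program passes when calling gen(). A minimal sketch of such a call, assuming a default backend has already been registered (e.g. via sgl.set_default_backend); the prompt and parameter values are illustrative:

import sglang as sgl

@sgl.function
def answer(s, question):
    s += "Question: " + question + "\n"
    # the sampling arguments here are the ones documented in
    # docs/backend/sampling_params.md
    s += "Answer: " + sgl.gen("answer", max_tokens=64, temperature=0.7, stop="\n")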

@@ -457,7 +457,7 @@ class SglGen(SglExpr):
         regex: Optional[str] = None,
         json_schema: Optional[str] = None,
     ):
-        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+        """Call the model to generate. See the meaning of the arguments in docs/backend/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
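
The regex and json_schema arguments visible in this hunk drive constrained decoding; they are forwarded into the SglGen expression's sampling params. A hedged sketch of regex-constrained generation (the prompt and pattern are illustrative):

import sglang as sgl

@sgl.function
def ip_address(s):
    s += "The IPv4 address of the DNS server is: "
    # constrain decoding to a dotted-quad pattern
    s += sgl.gen("ip", regex=r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}")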

@@ -22,7 +22,7 @@ class SamplingParams:
     """
     The sampling parameters.
-    See docs/references/sampling_params.md or
+    See docs/backend/sampling_params.md or
     https://docs.sglang.ai/backend/sampling_params.html
     for the documentation.
     """
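
On the server side, the same parameters are accepted under the sampling_params key of the native /generate endpoint, as described in docs/backend/sampling_params.md. A minimal sketch, assuming a server listening on localhost:30000 (the address and values are illustrative):

import requests

response = requests.post(
    "http://localhost:30000/generate",  # assumed local server address
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 16,
            "frequency_penalty": 0.5,
        },
    },
)
print(response.json())

The removed test utilities follow.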
import dataclasses
import enum
import unittest
from typing import Dict, List, Optional, Set, Tuple, Type

import torch

from sglang.srt.sampling.penaltylib.orchestrator import (
    BatchedPenalizerOrchestrator,
    _BatchedPenalizer,
    _BatchLike,
)


@dataclasses.dataclass
class MockSamplingParams:
    frequency_penalty: float = 0.0
    min_new_tokens: int = 0
    stop_token_ids: Optional[List[int]] = None
    presence_penalty: float = 0.0
    repetition_penalty: float = 1.0


@dataclasses.dataclass
class MockTokenizer:
    eos_token_id: int
    additional_stop_token_ids: Optional[List[int]] = None


@dataclasses.dataclass
class MockReq:
    origin_input_ids: List[int]
    sampling_params: MockSamplingParams
    tokenizer: MockTokenizer


class StepType(enum.Enum):
    INPUT = "input"
    OUTPUT = "output"


@dataclasses.dataclass
class Step:
    type: StepType
    token_ids: List[int]
    expected_tensors: Dict[str, torch.Tensor]
    # assume initial logits are all 1
    expected_logits: torch.Tensor


@dataclasses.dataclass
class Subject:
    sampling_params: MockSamplingParams
    # the first step must be an INPUT step; it is converted to a MockReq
    steps: List[Step]
    eos_token_id: int = -1

    def __post_init__(self):
        if self.steps[0].type != StepType.INPUT:
            raise ValueError("First step must be input")

        # every step must have the same expected_tensors keys
        for i in range(1, len(self.steps)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensor keys must be the same for all steps. "
                    f"Got {self.steps[i].expected_tensors.keys()} for steps[{i}] "
                    f"and {self.steps[0].expected_tensors.keys()} for steps[0]"
                )

    def tensor_keys(self, i: int = 0) -> Set[str]:
        return set(self.steps[i].expected_tensors.keys())

    def to_req(self) -> MockReq:
        return MockReq(
            origin_input_ids=self.steps[0].token_ids,
            sampling_params=self.sampling_params,
            tokenizer=MockTokenizer(eos_token_id=self.eos_token_id),
        )


@dataclasses.dataclass
class Case:
    enabled: bool
    test_subjects: List[Subject]

    def __post_init__(self):
        # every test subject must have the same expected_tensors keys
        for i in range(1, len(self.test_subjects)):
            if self.tensor_keys(i) != self.tensor_keys():
                raise ValueError(
                    f"Expected tensor keys must be the same for all test subjects. "
                    f"Got {self.test_subjects[i].tensor_keys()} for test_subjects[{i}] "
                    f"and {self.test_subjects[0].tensor_keys()} for test_subjects[0]"
                )

    def tensor_keys(self, i: int = 0) -> Set[str]:
        return set(self.test_subjects[i].tensor_keys())


class BaseBatchedPenalizerTest(unittest.TestCase):
    Penalizer: Type[_BatchedPenalizer]
    device = "cuda"
    vocab_size = 5

    enabled: Optional[Subject] = None
    disabled: Optional[Subject] = None

    def setUp(self):
        if self.__class__ == BaseBatchedPenalizerTest:
            self.skipTest("Base class for penalizer tests")

        self.create_test_subjects()
        self.create_test_cases()

    def tensor(self, data, **kwargs) -> torch.Tensor:
        """
        Shortcut to create a tensor with device=self.device.
        """
        return torch.tensor(data, **kwargs, device=self.device)

    def create_test_subjects(self):
        # Subclasses must set self.enabled and self.disabled here.
        raise NotImplementedError()

    def create_test_cases(self):
        self.test_cases = [
            Case(enabled=True, test_subjects=[self.enabled]),
            Case(enabled=False, test_subjects=[self.disabled]),
            Case(enabled=True, test_subjects=[self.enabled, self.disabled]),
        ]

    def _create_penalizer(
        self, case: Case
    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
        orchestrator = BatchedPenalizerOrchestrator(
            vocab_size=self.vocab_size,
            batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
            device=self.device,
            Penalizers={self.Penalizer},
        )
        return orchestrator, orchestrator.penalizers[self.Penalizer]

    def test_is_required(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_required())

    def test_prepare(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)
                self.assertEqual(case.enabled, penalizer.is_prepared())

                if case.enabled:
                    for key, tensor in {
                        key: torch.cat(
                            tensors=[
                                subject.steps[0].expected_tensors[key]
                                for subject in case.test_subjects
                            ],
                        )
                        for key in case.tensor_keys()
                    }.items():
                        torch.testing.assert_close(
                            actual=getattr(penalizer, key),
                            expected=tensor,
                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                        )

                original = torch.ones(
                    size=(len(case.test_subjects), self.vocab_size),
                    dtype=torch.float32,
                    device=self.device,
                )
                actual = orchestrator.apply(original.clone())
                expected = torch.cat(
                    tensors=[
                        subject.steps[0].expected_logits
                        for subject in case.test_subjects
                    ],
                )
                if actual is None:
                    actual = original
                torch.testing.assert_close(
                    actual=actual,
                    expected=expected,
                    msg=f"logits\nactual={actual}\nexpected={expected}",
                )

    def test_teardown(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                _, penalizer = self._create_penalizer(case)
                penalizer.teardown()

                for key in case.test_subjects[0].steps[0].expected_tensors.keys():
                    self.assertIsNone(getattr(penalizer, key, None))

    def test_filter(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                indices_to_keep = [0]
                orchestrator.filter(indices_to_keep=indices_to_keep)

                filtered_subjects = [case.test_subjects[i] for i in indices_to_keep]

                if penalizer.is_required():
                    self.assertTrue(penalizer.is_prepared())
                    for key, tensor in {
                        key: torch.cat(
                            tensors=[
                                subject.steps[0].expected_tensors[key]
                                for subject in filtered_subjects
                            ],
                        )
                        for key in case.tensor_keys()
                    }.items():
                        torch.testing.assert_close(
                            actual=getattr(penalizer, key),
                            expected=tensor,
                            msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                        )

                actual_logits = orchestrator.apply(
                    torch.ones(
                        size=(len(filtered_subjects), self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                )
                if actual_logits is None:
                    continue
                filtered_expected_logits = torch.cat(
                    tensors=[
                        subject.steps[0].expected_logits
                        for subject in filtered_subjects
                    ],
                )
                torch.testing.assert_close(
                    actual=actual_logits,
                    expected=filtered_expected_logits,
                    msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
                )

    def test_merge_enabled_with_disabled(self):
        enabled_test_case = self.test_cases[0]
        disabled_test_case = self.test_cases[1]
        orchestrator, penalizer = self._create_penalizer(enabled_test_case)
        theirs, _ = self._create_penalizer(disabled_test_case)

        orchestrator.merge(theirs)

        for key, tensor in {
            key: torch.cat(
                tensors=[
                    enabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
                    disabled_test_case.test_subjects[0].steps[0].expected_tensors[key],
                ],
            )
            for key in enabled_test_case.tensor_keys()
        }.items():
            torch.testing.assert_close(
                actual=getattr(penalizer, key),
                expected=tensor,
                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
            )

    def test_cumulate_apply_repeat(self):
        for case in self.test_cases:
            with self.subTest(case=case):
                orchestrator, penalizer = self._create_penalizer(case)

                max_step = max(len(subject.steps) for subject in case.test_subjects)
                for i in range(1, max_step):
                    orchestrator.filter(
                        indices_to_keep=[
                            j
                            for j, subject in enumerate(case.test_subjects)
                            if i < len(subject.steps)
                        ]
                    )
                    filtered_subjects = [
                        subject
                        for subject in case.test_subjects
                        if i < len(subject.steps)
                    ]

                    # inputs is kept for symmetry; only OUTPUT steps are
                    # exercised by this test
                    inputs: List[List[int]] = []
                    outputs: List[List[int]] = []
                    for subject in filtered_subjects:
                        step = subject.steps[i]
                        if step.type == StepType.INPUT:
                            raise NotImplementedError()
                        else:
                            inputs.append([])
                            outputs.append(step.token_ids)

                    if any(outputs):
                        for j in range(max(len(x) for x in outputs)):
                            tmp_outputs = torch.tensor(
                                [x[j] for x in outputs],
                                dtype=torch.int32,
                                device=orchestrator.device,
                            )
                            orchestrator.cumulate_output_tokens(tmp_outputs)

                    if penalizer.is_required():
                        self.assertTrue(penalizer.is_prepared())
                        for key, tensor in {
                            key: torch.cat(
                                tensors=[
                                    subject.steps[i].expected_tensors[key]
                                    for subject in filtered_subjects
                                ],
                            )
                            for key in case.tensor_keys()
                        }.items():
                            torch.testing.assert_close(
                                actual=getattr(penalizer, key),
                                expected=tensor,
                                msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                            )

                    original = torch.ones(
                        size=(len(filtered_subjects), self.vocab_size),
                        dtype=torch.float32,
                        device=self.device,
                    )
                    actual_logits = orchestrator.apply(original.clone())
                    filtered_expected_logits = torch.cat(
                        tensors=[
                            subject.steps[i].expected_logits
                            for subject in filtered_subjects
                        ],
                    )
                    if actual_logits is None:
                        actual_logits = original
                    torch.testing.assert_close(
                        actual=actual_logits,
                        expected=filtered_expected_logits,
                        msg=f"logits\nactual={actual_logits}\nexpected={filtered_expected_logits}",
                    )
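
For reference, a concrete test built on this now-removed harness would subclass BaseBatchedPenalizerTest and define the two subjects that create_test_cases combines. The sketch below is hypothetical and not part of this commit: the import path, the penalizer class, and the tensor key names ("frequency_penalties", "cumulated_frequency_penalties") are assumptions.

from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
    BatchedFrequencyPenalizer,  # assumed import path
)


class TestBatchedFrequencyPenalizer(BaseBatchedPenalizerTest):
    Penalizer = BatchedFrequencyPenalizer

    def create_test_subjects(self):
        # one subject with the penalty enabled, one with it disabled
        self.enabled = Subject(
            sampling_params=MockSamplingParams(frequency_penalty=0.5),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        # assumed key names; a real penalizer defines its own
                        "frequency_penalties": self.tensor(
                            [[0.5] * self.vocab_size], dtype=torch.float32
                        ),
                        "cumulated_frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    # input tokens leave the all-ones logits unchanged
                    expected_logits=self.tensor(
                        [[1.0] * self.vocab_size], dtype=torch.float32
                    ),
                ),
            ],
        )
        self.disabled = Subject(
            sampling_params=MockSamplingParams(frequency_penalty=0.0),
            steps=[
                Step(
                    type=StepType.INPUT,
                    token_ids=[0, 1, 2],
                    expected_tensors={
                        "frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                        "cumulated_frequency_penalties": self.tensor(
                            [[0.0] * self.vocab_size], dtype=torch.float32
                        ),
                    },
                    expected_logits=self.tensor(
                        [[1.0] * self.vocab_size], dtype=torch.float32
                    ),
                ),
            ],
        )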