Unverified Commit ac2324c1 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Skip the flaky test_stateful_custom_logit_processor (#6251)

parent ef8ec07b
......@@ -28,11 +28,26 @@ class CustomLogitProcessor(ABC):
"""Define the callable behavior."""
raise NotImplementedError
def to_str(self) -> str:
@classmethod
def to_str(cls) -> str:
"""Serialize the callable function to a JSON-compatible string."""
return json.dumps({"callable": dill.dumps(self).hex()})
return json.dumps({"callable": dill.dumps(cls).hex()})
@classmethod
def from_str(cls, json_str: str):
"""Deserialize a callable function from a JSON string."""
return _cache_from_str(json_str)
return _cache_from_str(json_str)()
class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
    """Logit processor that masks out a set of disallowed token ids.

    The ids are read from the ``"token_ids"`` entry of the first element of
    ``custom_param_list``; every element is asserted to carry an equal
    ``"token_ids"`` value.
    """

    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        """Set the logits of the disallowed token ids to ``-inf`` in place.

        Args:
            logits: Tensor whose last dimension is indexed by token id.
            custom_param_list: Parameter dicts, each holding a
                ``"token_ids"`` entry. NOTE(review): the default is ``None``
                but the body indexes it unconditionally -- callers appear to
                be expected to always pass a non-empty list; confirm.

        Returns:
            The same ``logits`` tensor, modified in place.
        """
        # The first entry acts as the reference that all others must match.
        banned_ids = custom_param_list[0]["token_ids"]
        assert all(
            banned_ids == params["token_ids"] for params in custom_param_list
        ), f"{custom_param_list=}"

        # Mask along the last axis so the banned tokens can never be sampled.
        logits[..., banned_ids] = -float("inf")
        return logits
......@@ -344,9 +344,7 @@ class TestSRTEndpoint(CustomTestCase):
custom_json = base_json.copy()
# Only set the custom logit processor if target_token_id is not None.
if target_token_id is not None:
custom_json["custom_logit_processor"] = (
DeterministicLogitProcessor().to_str()
)
custom_json["custom_logit_processor"] = DeterministicLogitProcessor.to_str()
custom_json["sampling_params"]["custom_params"] = custom_params
custom_response = requests.post(
......@@ -373,7 +371,6 @@ class TestSRTEndpoint(CustomTestCase):
Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
If first_token_id is None, the custom logit processor won't be passed in.
"""
custom_params = {"token_id": first_token_id, "delay": 2}
class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
......@@ -447,10 +444,22 @@ class TestSRTEndpoint(CustomTestCase):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
@unittest.skip("Skip this test because this feature has a bug. See comments below.")
def test_stateful_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
"""
NOTE: This feature has a race condition bug.
The code at https://github.com/sgl-project/sglang/blob/ef8ec07b2ce4c70c2a33ec5acda4ce529bc3cda4/test/srt/test_srt_endpoint.py#L395-L396 can be accessed by two concurrent threads at the same time, and the access order is not guaranteed.
In sglang, we use two python threads to overlap the GPU computation and CPU scheduling.
Thread 1 (the CPU scheduling thread) will update the `param_dict["__req__"].output_ids`.
Thread 2 (the GPU computation thread) will call `DeterministicStatefulLogitProcessor` because sampling is considered GPU computation.
We can fix this by moving the call of DeterministicStatefulLogitProcessor to the CPU scheduling thread.
"""
self.run_stateful_custom_logit_processor(first_token_id=5)
@unittest.skip("Skip this test because this feature has a bug. See comments above.")
def test_stateful_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment