"gallery/vscode:/vscode.git/clone" did not exist on "11e49de410ec84ec669293a91dfaa13a53c9bc47"
Unverified Commit 583697cd authored by Hongpeng Guo's avatar Hongpeng Guo Committed by GitHub
Browse files

[Enhancement] Custom Logit Processor Improvement (#2998)


Signed-off-by: default avatarHongpeng Guo <hpguo@anyscale.com>
parent 2584f6d9
......@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
......
......@@ -132,6 +132,11 @@ class Sampler(nn.Module):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""
assert logits.shape[0] == len(sampling_batch_info), (
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
for _, (
processor,
batch_mask,
......@@ -139,6 +144,11 @@ class Sampler(nn.Module):
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
assert batch_mask.shape[0] == len(sampling_batch_info), (
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],
......
......@@ -595,6 +595,9 @@ class ScheduleBatch:
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
# Enable custom logit processor
enable_custom_logit_processor: bool = False
@classmethod
def init_new(
cls,
......@@ -605,6 +608,7 @@ class ScheduleBatch:
model_config: ModelConfig,
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
):
return cls(
reqs=reqs,
......@@ -618,6 +622,7 @@ class ScheduleBatch:
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
)
def batch_size(self):
......@@ -1201,6 +1206,7 @@ class ScheduleBatch:
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
enable_custom_logit_processor=self.enable_custom_logit_processor,
)
def __str__(self):
......
......@@ -966,6 +966,7 @@ class Scheduler:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
new_batch.prepare_for_extend()
......@@ -1520,6 +1521,7 @@ class Scheduler:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
idle_batch.prepare_for_idle()
return idle_batch
......
......@@ -89,7 +89,10 @@ class SamplingBatchInfo:
).to(device, non_blocking=True)
# Check if any request has custom logit processor
has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
......@@ -247,8 +250,7 @@ class SamplingBatchInfo:
self, unfinished_indices: List[int], new_indices: torch.Tensor
):
"""Filter the custom logit processor and custom params"""
if not self.custom_logit_processor:
return
self.custom_logit_processor = {
k: (p, mask[new_indices])
for k, (p, mask) in self.custom_logit_processor.items()
......@@ -258,7 +260,9 @@ class SamplingBatchInfo:
}
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
if len(self) == 0:
# If the custom logit processor is an empty dict, set the flag to False,
# and set the custom logit processor and custom params to None.
if len(self.custom_logit_processor) == 0:
self.custom_logit_processor = None
self.custom_params = None
self.has_custom_logit_processor = False
......@@ -290,8 +294,8 @@ class SamplingBatchInfo:
@staticmethod
def merge_custom_logit_processor(
lhs: Optional[Dict[str, torch.Tensor]],
rhs: Optional[Dict[str, torch.Tensor]],
lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
bs1: int,
bs2: int,
device: str,
......@@ -319,27 +323,22 @@ class SamplingBatchInfo:
)
merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
assert merged_dict[k][1].shape[0] == bs1 + bs2, (
f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
f"\n{lhs=}\n{rhs=}"
)
return merged_dict
def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
# Merge the logit bias tensor
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors
......@@ -360,6 +359,22 @@ class SamplingBatchInfo:
# Set the flag to True if any of the two has custom logit processor
self.has_custom_logit_processor = True
# Note: becasue the __len()__ operator is defined on the temperatures tensor,
# please make sure any merge operation with len(self) or len(other) is done before
# the merge operation of the temperatures tensor below.
for item in [
"temperatures",
"top_ps",
"top_ks",
"min_ps",
]:
self_val = getattr(self, item, None)
other_val = getattr(other, item, None)
setattr(self, item, torch.concat([self_val, other_val]))
self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
def apply_logits_bias(self, logits: torch.Tensor):
# Apply logit_bias
if self.logit_bias is not None:
......
......@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
"""
import json
import random
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import numpy as np
import requests
......@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
self.assertTrue(all(x is not None for x in logprobs))
def run_custom_logit_processor(self, target_token_id: int):
"""Test custom logit processor with custom params."""
def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
"""Test custom logit processor with custom params.
If target_token_id is None, the custom logit processor won't be passed in.
"""
custom_params = {"token_id": target_token_id}
......@@ -285,8 +290,12 @@ class TestSRTEndpoint(unittest.TestCase):
# Custom json data with custom logit processor and params.
custom_json = base_json.copy()
custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
custom_json["sampling_params"]["custom_params"] = custom_params
# Only set the custom logit processor if target_token_id is not None.
if target_token_id is not None:
custom_json["custom_logit_processor"] = (
DeterministicLogitProcessor().to_str()
)
custom_json["sampling_params"]["custom_params"] = custom_params
custom_response = requests.post(
self.base_url + "/generate",
......@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
sampled_tokens = [x[1] for x in output_token_logprobs]
# The logit processor should always sample the given token as the logits is deterministic.
self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
if target_token_id is not None:
self.assertTrue(
all(x == custom_params["token_id"] for x in sampled_tokens),
# Print the detailed test case info if the test fails.
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
# Temporarily skipped due to buggy implementation
return
self.run_custom_logit_processor(target_token_id=5)
def test_custom_logit_processor_batch(self):
"""Test custom logit processor with a batch of requests."""
# Temporarily skipped due to buggy implementation
return
target_token_ids = list(range(32))
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
random.shuffle(target_token_ids)
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment