Unverified Commit 583697cd authored by Hongpeng Guo, committed by GitHub

[Enhancement] Custom Logit Processor Improvement (#2998)


Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
parent 2584f6d9
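
For context before the diff: the custom_logit_processor that requests carry, that the tests below serialize with to_str(), and that the Sampler applies per batch mask is assumed to follow an interface roughly like this sketch. The dill/base64 serialization and the exact signatures are illustrative assumptions, not verified sglang code.

# Hedged sketch of the assumed custom logit processor interface; the dill/base64
# serialization and exact signatures are assumptions, not verified sglang code.
import base64
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import dill
import torch


class CustomLogitProcessor(ABC):
    """A callable applied to the logits of the requests that registered it."""

    @abstractmethod
    def __call__(
        self, logits: torch.Tensor, custom_param_list: Optional[List[Dict]]
    ) -> torch.Tensor:
        """Return (or modify in place) logits of shape [batch_size, vocab_size]."""
        ...

    def to_str(self) -> str:
        """Serialize the processor so it can be shipped in a JSON request body."""
        return base64.b64encode(dill.dumps(self)).decode("utf-8")

    @classmethod
    def from_str(cls, s: str) -> "CustomLogitProcessor":
        """Inverse of to_str(); reconstructs the processor on the server side."""
        return dill.loads(base64.b64decode(s))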
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
+        enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
...
@@ -132,6 +132,11 @@ class Sampler(nn.Module):
         """Apply custom logit processors to the logits.
         This function will modify the logits in-place."""
 
+        assert logits.shape[0] == len(sampling_batch_info), (
+            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+
         for _, (
             processor,
             batch_mask,
@@ -139,6 +144,11 @@ class Sampler(nn.Module):
             # Get the batch indices that need to be processed
             batch_indices = batch_mask.nonzero(as_tuple=True)[0]
 
+            assert batch_mask.shape[0] == len(sampling_batch_info), (
+                f"The length of the batch mask ({batch_mask.shape[0]}) does not match the "
+                f"batch size of sampling_batch_info ({len(sampling_batch_info)})"
+            )
+
             # Apply the processor to the logits
             logits[batch_mask] = processor(
                 logits[batch_mask],
...
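
The two assertions added in this hunk pin down the invariant that the logits tensor and every processor's batch mask share the batch dimension. Below is a standalone sketch of the masked, in-place application they guard; it is simplified, and the (processor, mask) pairs stand in for what the real Sampler reads from sampling_batch_info.custom_logit_processor.

# Simplified sketch of masked, in-place custom logit processing; the arguments
# stand in for what Sampler reads from SamplingBatchInfo.custom_logit_processor.
from typing import Callable, Dict, List, Optional, Tuple

import torch


def apply_custom_logit_processors(
    logits: torch.Tensor,  # [batch_size, vocab_size]
    processors: Dict[int, Tuple[Callable, torch.Tensor]],  # id -> (processor, bool mask)
    custom_params: List[Optional[dict]],  # one entry per request in the batch
) -> None:
    batch_size = logits.shape[0]
    for _, (processor, batch_mask) in processors.items():
        # The invariants mirrored from the added assertions.
        assert batch_mask.shape[0] == batch_size
        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
        # Apply the processor only to the rows whose requests registered it,
        # passing that subset's custom_params alongside.
        logits[batch_mask] = processor(
            logits[batch_mask],
            [custom_params[i] for i in batch_indices.tolist()],
        )


# Tiny usage example: boost token 3 for the first of two requests.
if __name__ == "__main__":
    logits = torch.zeros(2, 8)
    mask = torch.tensor([True, False])
    boost = lambda lg, params: lg.index_fill_(
        1, torch.tensor([p["token_id"] for p in params]), 10.0
    )
    apply_custom_logit_processors(logits, {0: (boost, mask)}, [{"token_id": 3}, None])
    print(logits[0].argmax().item())  # 3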
@@ -595,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -605,6 +608,7 @@ class ScheduleBatch:
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -618,6 +622,7 @@ class ScheduleBatch:
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
@@ -1201,6 +1206,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
...
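
Taken together with the scheduler change below, the new field simply threads a server-level switch down to where SamplingBatchInfo is built. A minimal sketch of that wiring with simplified stand-in classes (the real init_new takes many more arguments):

# Hedged sketch of how the new flag threads through batch construction; the class
# and attribute names mirror the diff, but the bodies are simplified stand-ins.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Req:
    custom_logit_processor: Optional[str] = None  # serialized processor, if any


@dataclass
class ScheduleBatch:
    reqs: List[Req]
    enable_custom_logit_processor: bool = False  # new field added by this change

    @classmethod
    def init_new(cls, reqs: List[Req], enable_custom_logit_processor: bool) -> "ScheduleBatch":
        # The scheduler passes server_args.enable_custom_logit_processor here.
        return cls(reqs=reqs, enable_custom_logit_processor=enable_custom_logit_processor)


def has_custom_logit_processor(batch: ScheduleBatch) -> bool:
    # Mirrors SamplingBatchInfo.from_schedule_batch: the server-level flag gates
    # the per-request processors, so nothing is applied unless both are set.
    return batch.enable_custom_logit_processor and any(
        r.custom_logit_processor for r in batch.reqs
    )


batch = ScheduleBatch.init_new([Req(custom_logit_processor="...")], enable_custom_logit_processor=False)
assert has_custom_logit_processor(batch) is False  # flag off -> feature stays off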
@@ -966,6 +966,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ class Scheduler:
             self.model_config,
             self.enable_overlap,
             self.spec_algorithm,
+            self.server_args.enable_custom_logit_processor,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
...
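
From a client's point of view, the feature is only active when the server was launched with the corresponding server argument and the request itself supplies a processor. A hedged usage sketch follows; the --enable-custom-logit-processor flag name, the /generate endpoint, and the payload fields are assumptions inferred from the tests below rather than a verified API surface.

# Hedged client-side sketch; flag name, endpoint path, and payload fields are
# assumptions inferred from the test changes below, not a verified API.
#
# Launch the server with the feature enabled, e.g. (assumed flag name):
#   python -m sglang.launch_server --model-path <model> --enable-custom-logit-processor
import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "max_new_tokens": 8,
        "custom_params": {"token_id": 5},  # forwarded to the processor per request
    },
    # Serialized processor produced by CustomLogitProcessor.to_str(); requests that
    # omit this field are sampled normally, even in the same batch.
    "custom_logit_processor": "<output of DeterministicLogitProcessor().to_str()>",
}

response = requests.post("http://localhost:30000/generate", json=payload, timeout=60)
print(response.json())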
@@ -89,7 +89,10 @@ class SamplingBatchInfo:
         ).to(device, non_blocking=True)
 
         # Check if any request has custom logit processor
-        has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
 
         if has_custom_logit_processor:
             # Merge the same type of custom logit processors together
@@ -247,8 +250,7 @@ class SamplingBatchInfo:
         self, unfinished_indices: List[int], new_indices: torch.Tensor
     ):
         """Filter the custom logit processor and custom params"""
-        if not self.custom_logit_processor:
-            return
+
         self.custom_logit_processor = {
             k: (p, mask[new_indices])
             for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ class SamplingBatchInfo:
         }
         self.custom_params = [self.custom_params[i] for i in unfinished_indices]
 
-        if len(self) == 0:
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
             self.custom_logit_processor = None
             self.custom_params = None
             self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ class SamplingBatchInfo:
     @staticmethod
     def merge_custom_logit_processor(
-        lhs: Optional[Dict[str, torch.Tensor]],
-        rhs: Optional[Dict[str, torch.Tensor]],
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
         bs1: int,
         bs2: int,
         device: str,
@@ -319,27 +323,22 @@ class SamplingBatchInfo:
             )
             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
 
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of the merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
         return merged_dict
 
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        # Merge the logit bias tensor
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
 
         # Merge the custom logit processors and custom params lists
         if self.has_custom_logit_processor or other.has_custom_logit_processor:
             # Merge the custom logit processors
@@ -360,6 +359,22 @@ class SamplingBatchInfo:
             # Set the flag to True if any of the two has custom logit processor
             self.has_custom_logit_processor = True
 
+        # Note: because the __len__() operator is defined on the temperatures tensor,
+        # make sure any merge operation that uses len(self) or len(other) happens
+        # before the temperatures tensor is merged below.
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
         if self.logit_bias is not None:
...
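
The assertion added to merge_custom_logit_processor checks that every merged mask covers the concatenated batch. A standalone sketch of the merge semantics it protects: a side that never saw a given processor contributes an all-False mask of its own batch size, so the concatenated mask always has length bs1 + bs2.

# Standalone sketch of the mask-merging semantics behind the added assertion;
# simplified from SamplingBatchInfo.merge_custom_logit_processor.
from typing import Callable, Dict, Optional, Tuple

import torch


def merge_custom_logit_processor(
    lhs: Optional[Dict[int, Tuple[Callable, torch.Tensor]]],
    rhs: Optional[Dict[int, Tuple[Callable, torch.Tensor]]],
    bs1: int,
    bs2: int,
    device: str = "cpu",
) -> Optional[Dict[int, Tuple[Callable, torch.Tensor]]]:
    if lhs is None and rhs is None:
        return None
    lhs, rhs = lhs or {}, rhs or {}

    merged = {}
    for key in set(lhs) | set(rhs):
        processor = (lhs.get(key) or rhs.get(key))[0]
        # A side that never saw this processor contributes an all-False mask.
        left_mask = lhs[key][1] if key in lhs else torch.zeros(bs1, dtype=torch.bool, device=device)
        right_mask = rhs[key][1] if key in rhs else torch.zeros(bs2, dtype=torch.bool, device=device)
        merged[key] = (processor, torch.cat([left_mask, right_mask]))
        # Invariant checked by the added assertion: the merged mask spans both batches.
        assert merged[key][1].shape[0] == bs1 + bs2
    return merged


# Usage: lhs batch of 2 uses processor id 7 on its first row only; rhs batch of 3 doesn't.
proc = lambda logits, params: logits
lhs = {7: (proc, torch.tensor([True, False]))}
out = merge_custom_logit_processor(lhs, None, bs1=2, bs2=3)
print(out[7][1])  # tensor([ True, False, False, False, False])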
@@ -4,8 +4,10 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 """
 
 import json
+import random
 import unittest
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 import numpy as np
 import requests
@@ -253,8 +255,11 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertTrue(all(x is not None for x in logprobs))
 
-    def run_custom_logit_processor(self, target_token_id: int):
-        """Test custom logit processor with custom params."""
+    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
+        """Test custom logit processor with custom params.
+
+        If target_token_id is None, the custom logit processor won't be passed in.
+        """
 
         custom_params = {"token_id": target_token_id}
@@ -285,7 +290,11 @@ class TestSRTEndpoint(unittest.TestCase):
         # Custom json data with custom logit processor and params.
         custom_json = base_json.copy()
-        custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
+        # Only set the custom logit processor if target_token_id is not None.
+        if target_token_id is not None:
+            custom_json["custom_logit_processor"] = (
+                DeterministicLogitProcessor().to_str()
+            )
         custom_json["sampling_params"]["custom_params"] = custom_params
 
         custom_response = requests.post(
@@ -297,22 +306,30 @@ class TestSRTEndpoint(unittest.TestCase):
         sampled_tokens = [x[1] for x in output_token_logprobs]
         # The logit processor should always sample the given token as the logits is deterministic.
-        self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+        if target_token_id is not None:
+            self.assertTrue(
+                all(x == custom_params["token_id"] for x in sampled_tokens),
+                # Print the detailed test case info if the test fails.
+                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+            )
 
     def test_custom_logit_processor(self):
         """Test custom logit processor with a single request."""
-        # Temporarily skipped due to buggy implementation
-        return
         self.run_custom_logit_processor(target_token_id=5)
 
     def test_custom_logit_processor_batch(self):
         """Test custom logit processor with a batch of requests."""
-        # Temporarily skipped due to buggy implementation
-        return
         target_token_ids = list(range(32))
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))
 
+    def test_custom_logit_processor_batch_mixed(self):
+        """Test a batch that mixes requests with and without a custom logit processor."""
+        target_token_ids = list(range(32)) + [None] * 16
+        random.shuffle(target_token_ids)
+        with ThreadPoolExecutor(len(target_token_ids)) as executor:
+            list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
...
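
The DeterministicLogitProcessor the test serializes is not part of this diff. A processor with the behavior the test asserts (the sampled token always equals custom_params["token_id"]) could look roughly like the sketch below; the class is illustrative and standalone, not the test's actual implementation.

# Hedged sketch of a processor matching the behavior the tests assert: it pushes the
# per-request target token's logit far above everything else, so sampling always
# returns that token. Illustrative only, not the test's actual DeterministicLogitProcessor.
from typing import Dict, List, Optional

import torch


class DeterministicLogitProcessorSketch:
    def __call__(
        self, logits: torch.Tensor, custom_param_list: Optional[List[Dict]]
    ) -> torch.Tensor:
        assert custom_param_list is not None and len(custom_param_list) == logits.shape[0]
        for i, params in enumerate(custom_param_list):
            # Each request carries {"token_id": ...} via sampling_params.custom_params.
            logits[i, :] = -1e9
            logits[i, params["token_id"]] = 1e9
        return logits


# Usage mirroring the batched test: request i should always emit token i.
logits = torch.randn(4, 16)
out = DeterministicLogitProcessorSketch()(logits, [{"token_id": i} for i in range(4)])
print(out.argmax(dim=-1))  # tensor([0, 1, 2, 3])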