Unverified Commit ca929118, authored by Brayden Zhong and committed by GitHub

[Feature] Add Logit Bias (#6579)


Co-authored-by: Cinjon Resnick <cinjon.resnick@gmail.com>
parent 344adb00
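
This commit plumbs a per-request `logit_bias` (a map from token ID to an additive bias on that token's logit) from the OpenAI-compatible request objects down through `SamplingParams` and `SamplingBatchInfo` to the sampler. As a rough usage sketch against the `/generate` endpoint, assuming an SGLang server is running locally on port 30000 (the base URL and the biased token ID 60704, "Paris" for Llama-3.2-1B-Instruct, are illustrative, not part of the diff):

```python
import requests

# Minimal sketch: base URL and token ID are assumptions for illustration.
base_url = "http://localhost:30000"

response = requests.post(
    base_url + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 1.0,
            "max_new_tokens": 4,
            # token ID (as a string) -> additive bias applied to that token's logit
            "logit_bias": {"60704": 100.0},
        },
    },
)
print(response.json())
```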
@@ -582,6 +582,7 @@ def v1_generate_request(
            "no_stop_trim": request.no_stop_trim,
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
        }
    )
    return_logprobs.append(request.logprobs is not None)
@@ -1219,6 +1220,7 @@ def v1_chat_generate_request(
        "no_stop_trim": request.no_stop_trim,
        "ignore_eos": request.ignore_eos,
        "skip_special_tokens": request.skip_special_tokens,
        "logit_bias": request.logit_bias,
    }
    if request.response_format and request.response_format.type == "json_schema":
......
@@ -10,6 +10,7 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
from sglang.srt.utils import merge_bias_tensor

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
    # Device
    device: str = "cuda"

    # Handle logit bias
    logit_bias: Optional[torch.Tensor] = None

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        reqs = batch.reqs
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
            [r.sampling_params.min_p for r in reqs], dtype=torch.float
        ).to(device, non_blocking=True)

        logit_bias = None
        if any(r.sampling_params.logit_bias is not None for r in reqs):
            logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
            for i, r in enumerate(reqs):
                if r.sampling_params.logit_bias is not None:
                    for key, value in r.sampling_params.logit_bias.items():
                        logit_bias[i, int(key)] = value

        # Check if any request has custom logit processor
        has_custom_logit_processor = (
            batch.enable_custom_logit_processor  # check the flag first.
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
            custom_params=custom_params,
            custom_logit_processor=merged_custom_logit_processor,
            device=device,
            logit_bias=logit_bias,
        )
        return ret
@@ -206,6 +219,9 @@ class SamplingBatchInfo:
        if self.vocab_mask is not None:
            self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)

        if self.logit_bias is not None:
            logits.add_(self.logit_bias)

    def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
        self.penalizer_orchestrator.filter(keep_indices_device)
@@ -221,6 +237,9 @@ class SamplingBatchInfo:
            value = getattr(self, item, None)
            setattr(self, item, value[keep_indices_device])

        if self.logit_bias is not None:
            self.logit_bias = self.logit_bias[keep_indices_device]

    def _filter_batch_custom_logit_processor(
        self, keep_indices: List[int], keep_indices_device: torch.Tensor
    ):
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
        self.need_top_p_sampling |= other.need_top_p_sampling
        self.need_top_k_sampling |= other.need_top_k_sampling
        self.need_min_p_sampling |= other.need_min_p_sampling

        # Merge logit bias
        self.logit_bias = merge_bias_tensor(
            self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
        )
@@ -52,6 +52,7 @@ class SamplingParams:
        no_stop_trim: bool = False,
        custom_params: Optional[Dict[str, Any]] = None,
        stream_interval: Optional[int] = None,
        logit_bias: Optional[Dict[str, float]] = None,
    ) -> None:
        self.max_new_tokens = max_new_tokens
        self.stop_strs = stop
@@ -78,6 +79,7 @@ class SamplingParams:
        self.no_stop_trim = no_stop_trim
        self.custom_params = custom_params
        self.stream_interval = stream_interval
        self.logit_bias = logit_bias

        # Process some special cases
        if 0 <= self.temperature < _SAMPLING_EPS:
......
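
For reference, the `logit_bias` stored on `SamplingParams` is a dict mapping token IDs (as strings) to float biases; `SamplingBatchInfo.from_schedule_batch` above densifies it into one vocabulary-sized row per request and adds that row to the logits at sampling time. A minimal sketch of that densification, with an illustrative vocabulary size and token IDs:

```python
import torch

# Per-request logit_bias as accepted by SamplingParams: string token IDs -> float biases.
logit_bias = {"60704": 100.0, "23994": -100.0}

vocab_size = 128256  # illustrative; the real value comes from the model config
row = torch.zeros(vocab_size)
for token_id, bias in logit_bias.items():
    row[int(token_id)] = bias

# At sampling time, the batch of such rows is added in place to the logits:
# logits.add_(logit_bias_tensor)
```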
@@ -2210,6 +2210,45 @@ class Withable(Generic[T]):
        self._value = None


def merge_bias_tensor(
    lhs: Optional[torch.Tensor],
    rhs: Optional[torch.Tensor],
    bs1: int,
    bs2: int,
    device: str,
    default: float,
):
    """Merge two bias tensors for batch merging.

    Args:
        lhs: Left-hand side tensor
        rhs: Right-hand side tensor
        bs1: Batch size of left-hand side tensor
        bs2: Batch size of right-hand side tensor
        device: Device to place the merged tensor on
        default: Default value for missing tensor elements

    Returns:
        Merged tensor or None if both inputs are None
    """
    if lhs is None and rhs is None:
        return None

    if lhs is not None and rhs is not None:
        return torch.cat([lhs, rhs])
    else:
        if lhs is not None:
            shape, dtype = lhs.shape[1:], lhs.dtype
        else:
            shape, dtype = rhs.shape[1:], rhs.dtype

        if lhs is None:
            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
        if rhs is None:
            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
        return torch.cat([lhs, rhs])


def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
    import huggingface_hub as hf
......
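
`merge_bias_tensor` covers the case where two running batches are merged but only one side carries a bias tensor: the missing side is padded with rows filled with the default value (0.0 for logit bias, i.e. no bias). A small illustration, with made-up shapes and assuming the sglang package is importable:

```python
import torch

from sglang.srt.utils import merge_bias_tensor  # helper added by this commit

# Batch A (2 requests) carries a bias tensor; batch B (3 requests) has none.
vocab_size = 8  # toy size for illustration
bias_a = torch.zeros(2, vocab_size)
bias_a[0, 5] = 100.0

merged = merge_bias_tensor(bias_a, None, 2, 3, "cpu", 0.0)
print(merged.shape)  # torch.Size([5, 8]); the three appended rows are all 0.0
```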
@@ -504,6 +504,122 @@ class TestSRTEndpoint(CustomTestCase):
        version = response_json["version"]
        self.assertIsInstance(version, str)

    def test_logit_bias(self):
        """Test that a very high logit bias forces sampling of a specific token."""
        # Token ID to bias: "Paris" for meta-llama/Llama-3.2-1B-Instruct
        # (DEFAULT_SMALL_MODEL_NAME_FOR_TEST).
        target_token_id = 60704
        logit_bias = {str(target_token_id): 100.0}  # Very high positive bias

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 1.0,  # Use high temperature to encourage exploration
                    "max_new_tokens": 4,
                    "logit_bias": logit_bias,
                },
                "return_logprob": True,
            },
        )
        response_json = response.json()

        # Extract the sampled token IDs from the output
        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # Verify that all sampled tokens are the target token
        self.assertTrue(
            all(x == target_token_id for x in sampled_tokens),
            f"Expected all tokens to be {target_token_id}, but got {sampled_tokens}",
        )

    def test_forbidden_token(self):
        """Test that a forbidden token (very negative logit bias) doesn't appear in the output."""
        # Token ID to forbid: "rice" for meta-llama/Llama-3.2-1B-Instruct
        # (DEFAULT_SMALL_MODEL_NAME_FOR_TEST).
        forbidden_token_id = 23994
        logit_bias = {
            str(forbidden_token_id): -100.0
        }  # Very negative bias to forbid the token

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "Only output 'rice' exactly like this, in lowercase ONLY: rice",
                "sampling_params": {
                    "temperature": 1.0,  # Use high temperature to encourage diverse output
                    "max_new_tokens": 50,  # Generate enough tokens for the forbidden token to be likely
                    "logit_bias": logit_bias,
                },
                "return_logprob": True,
            },
        )
        response_json = response.json()

        # Extract the sampled token IDs from the output
        output_token_logprobs = response_json["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # Verify that the forbidden token doesn't appear in the output
        self.assertNotIn(
            forbidden_token_id,
            sampled_tokens,
            f"Expected forbidden token {forbidden_token_id} not to be present, but it was found",
        )

    def test_logit_bias_isolation(self):
        """Test that logit_bias applied to one request doesn't affect other requests in batch."""
        # Token ID to bias in the first request only: "Paris" for
        # meta-llama/Llama-3.2-1B-Instruct (DEFAULT_SMALL_MODEL_NAME_FOR_TEST).
        biased_token_id = 60704

        # Prepare batch requests - one with logit_bias and one without
        requests_data = [
            {
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 4,
                    "logit_bias": {str(biased_token_id): 100.0},  # Strong bias
                },
                "return_logprob": True,
            },
            {
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 4,
                },
                "return_logprob": True,
            },
        ]

        # Send both requests
        responses = []
        for req in requests_data:
            response = requests.post(self.base_url + "/generate", json=req)
            responses.append(response.json())

        # Extract token IDs from each response
        biased_tokens = [
            x[1] for x in responses[0]["meta_info"]["output_token_logprobs"]
        ]
        unbiased_tokens = [
            x[1] for x in responses[1]["meta_info"]["output_token_logprobs"]
        ]

        # Verify that the first response contains only the biased token
        self.assertTrue(
            all(x == biased_token_id for x in biased_tokens),
            f"Expected all tokens to be {biased_token_id} in first response, but got {biased_tokens}",
        )

        # Verify that the second response contains at least some different tokens
        # (We can't guarantee exactly which tokens are generated, but they shouldn't all be the biased token)
        self.assertTrue(
            any(x != biased_token_id for x in unbiased_tokens),
            f"Expected some tokens to be different from {biased_token_id} in second response, but got {unbiased_tokens}",
        )

    def test_get_server_info_concurrent(self):
        """Make sure the concurrent get_server_info doesn't crash the server."""
        tp = ThreadPoolExecutor(max_workers=30)
......