Unverified Commit 6224a9f6 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

Support logit_bias in v1 Sampler (#13079)

parent 085b7b2d
# SPDX-License-Identifier: Apache-2.0
from typing import List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple
import numpy as np
import pytest
......@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)
def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res
def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
......@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata
......@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.logit_bias = _create_logit_bias(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
bias_value=bias_value,
)
sampler = Sampler()
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
biased_index = min(batch_idx, VOCAB_SIZE - 1)
for token_id in range(VOCAB_SIZE):
if biased_index == token_id:
assert logits_for_req[token_id] == pytest.approx(bias_value +
1e-2)
else:
assert logits_for_req[token_id] == pytest.approx(1e-2)
......@@ -45,9 +45,11 @@ def _remove_requests(
def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState], req_ids_retained: Set[int],
reqs: List[CachedRequestState],
req_ids_retained: Set[int],
req_id_index_in_input_batch: Dict[str, int],
device: torch.device) -> SamplingMetadata:
device: torch.device,
) -> SamplingMetadata:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
......@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)]
logit_bias = [None] * num_reqs
for req in reqs:
if req.req_id not in req_ids_retained:
continue
......@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[
index_in_input_batch] = req.sampling_params.frequency_penalty
repetition_penalties[
index_in_input_batch] = req.sampling_params.repetition_penalty
frequency_penalties[index_in_input_batch] = (
req.sampling_params.frequency_penalty)
repetition_penalties[index_in_input_batch] = (
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float, device=device),
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
......@@ -93,32 +97,34 @@ def _construct_expected_sampling_metadata(
no_top_k=all(x == 0 for x in top_k),
generators={},
max_num_logprobs=0,
prompt_token_ids= make_tensor_with_pad(
prompt_token_ids=make_tensor_with_pad(
prompt_token_ids,
pad=VOCAB_SIZE,
device=torch.device(device),
dtype=torch.int64,
),
frequency_penalties=torch.tensor(
frequency_penalties, dtype=torch.float,
frequency_penalties=torch.tensor(frequency_penalties,
dtype=torch.float,
device=device),
presence_penalties=torch.tensor(
presence_penalties, dtype=torch.float,
presence_penalties=torch.tensor(presence_penalties,
dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(
repetition_penalties, dtype=torch.float,
repetition_penalties=torch.tensor(repetition_penalties,
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x ==0 for x in presence_penalties) and \
all(x ==0 for x in frequency_penalties) and \
all(x ==1 for x in repetition_penalties))
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
)
def _create_sampling_params():
return SamplingParams(top_k=np.random.randint(1, 10),
return SamplingParams(
top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
......@@ -127,7 +133,9 @@ def _create_sampling_params():
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
])
],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)
def _construct_cached_request_state(req_id_suffix: int):
......@@ -139,7 +147,8 @@ def _construct_cached_request_state(req_id_suffix: int):
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
prompt=None,
sampling_params=_create_sampling_params(),
......@@ -148,7 +157,8 @@ def _construct_cached_request_state(req_id_suffix: int):
block_ids=[],
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids)
output_token_ids=output_token_ids,
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
......@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024)
vocab_size=1024,
)
reqs: List[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
......@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k)
assert torch.allclose(expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties)
assert torch.allclose(expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties)
assert torch.allclose(expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
)
assert torch.allclose(
expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert (
expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
assert (expected_sampling_metadata.stop_token_ids ==
sampling_metadata.stop_token_ids)
assert (expected_sampling_metadata.no_penalties ==
sampling_metadata.no_penalties)
assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.stop_token_ids == \
sampling_metadata.stop_token_ids
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
......@@ -243,8 +243,10 @@ class SamplingParams(
allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams":
if logit_bias is not None:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias = {
int(token): bias
int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items()
}
......
......@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]
logit_bias: List[Optional[Dict[int, float]]]
......@@ -37,6 +37,8 @@ class Sampler(nn.Module):
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature.
......@@ -166,3 +168,17 @@ class Sampler(nn.Module):
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
return logits
def apply_logits_bias(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
return logits
......@@ -141,8 +141,8 @@ class InputBatch:
dtype=torch.float,
device="cpu",
pin_memory=pin_memory)
self.presence_penalties_cpu = \
self.presence_penalties_cpu_tensor.numpy()
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
)
self.presence_penalties_reqs: Set[str] = set()
# Repetition penalty related data structures
......@@ -180,6 +180,9 @@ class InputBatch:
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}
self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs
def add_request(
self,
request: "CachedRequestState",
......@@ -220,16 +223,16 @@ class InputBatch:
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id)
self.frequency_penalties_cpu[req_index] = \
sampling_params.frequency_penalty
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = \
sampling_params.presence_penalty
self.presence_penalties_cpu[
req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = \
sampling_params.repetition_penalty
self.repetition_penalties_cpu[
req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
self.min_tokens[req_index] = sampling_params.min_tokens
......@@ -244,6 +247,8 @@ class InputBatch:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
# Add request lora ID
if request.lora_request:
......@@ -284,6 +289,7 @@ class InputBatch:
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
return req_index
def clear(self) -> None:
......@@ -302,6 +308,7 @@ class InputBatch:
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
self.logit_bias = [None] * self.max_num_reqs
def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
......@@ -332,8 +339,8 @@ class InputBatch:
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens
self.num_prompt_tokens[empty_index] = \
self.num_prompt_tokens[last_req_index]
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
......@@ -341,15 +348,15 @@ class InputBatch:
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = \
self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[empty_index] = \
self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[empty_index] = \
self.repetition_penalties_cpu[last_req_index]
self.frequency_penalties_cpu[
empty_index] = self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
self.stop_token_ids[empty_index] = \
self.stop_token_ids[last_req_index]
self.stop_token_ids[empty_index] = self.stop_token_ids[
last_req_index]
generator = self.generators.pop(last_req_index, None)
if generator is not None:
self.generators[empty_index] = generator
......@@ -357,6 +364,8 @@ class InputBatch:
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index]
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
# Decrement last_req_index since it is now empty.
last_req_index -= 1
......@@ -378,13 +387,16 @@ class InputBatch:
# penalties to be applied during sampling.
self.frequency_penalties[:self.num_reqs].copy_(
self.frequency_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
non_blocking=True,
)
self.presence_penalties[:self.num_reqs].copy_(
self.presence_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
non_blocking=True,
)
self.repetition_penalties[:self.num_reqs].copy_(
self.repetition_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True)
non_blocking=True,
)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
......@@ -421,6 +433,7 @@ class InputBatch:
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:self.num_reqs],
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
......@@ -429,10 +442,11 @@ class InputBatch:
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory)
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = (
self.token_ids_cpu[:self.num_reqs, :max_prompt_len])
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment