Unverified commit 6224a9f6, authored by Lu Fang and committed by GitHub
Browse files

Support logit_bias in v1 Sampler (#13079)

parent 085b7b2d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
import numpy as np import numpy as np
import pytest import pytest
...@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor( ...@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
) )
def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res
def _create_default_sampling_metadata( def _create_default_sampling_metadata(
num_output_tokens: int, num_output_tokens: int,
batch_size: int, batch_size: int,
...@@ -80,6 +92,7 @@ def _create_default_sampling_metadata( ...@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties=True, no_penalties=True,
min_tokens=[], min_tokens=[],
stop_token_ids=[], stop_token_ids=[],
logit_bias=[None] * batch_size,
) )
return fake_sampling_metadata return fake_sampling_metadata
...@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens( ...@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
batch_indices_for_min_token_penalty: List[int] batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]: ) -> Tuple[List[int], List[Set[int]]]:
""" """
Generates and returns a list of minimum token penalties (`min_tokens`) Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each and a corresponding list of stop token IDs (`stop_token_ids`) for each
batch. batch.
If a batch index is included in `batch_indices_for_min_token_penalty`, If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range), a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty. `min_tokens` value is assigned, and the stop token IDs set is empty.
""" """
stop_token_ids: List[Set[int]] = [] stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = [] min_tokens: List[int] = []
...@@ -120,7 +133,7 @@ def _create_weighted_output_token_list( ...@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
batch_size: int, batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]: vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
""" """
Creates an output token list where each token occurs a distinct Creates an output token list where each token occurs a distinct
number of times. number of times.
For each batch, a random subset of token IDs is selected from the For each batch, a random subset of token IDs is selected from the
...@@ -129,8 +142,8 @@ def _create_weighted_output_token_list( ...@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(
Returns: Returns:
Tuple[List[List[int]], List[List[int]]]: Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist - The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted corresponds to a batch and contains tokens with weighted
frequencies. frequencies.
- The second element is a list of distinct token IDs for each - The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output batch, ordered by their frequency in the corresponding output
...@@ -155,7 +168,7 @@ def _create_weighted_output_token_list( ...@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
@pytest.mark.parametrize("batch_size", [1, 2, 32]) @pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int): def test_sampler_min_tokens_penalty(device: str, batch_size: int):
""" """
Tests that if the number of output tokens is less than Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf. the stop token ids to -inf.
""" """
...@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int, ...@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def test_sampler_repetition_penalty(device: str, batch_size: int, def test_sampler_repetition_penalty(device: str, batch_size: int,
repetition_penalty: float): repetition_penalty: float):
""" """
Test to verify that when the repetition penalty is enabled, tokens Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing are penalized based on their presence in the prompt or the existing
output. output.
""" """
...@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, ...@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id not in output_tokens) penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \ assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens) non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
    """
    Test to verify that `Sampler.apply_logits_bias` adds the per-request
    logit bias to the biased token of each request and leaves the logits
    of every other token unchanged.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Replace the default (all-None) biases with one biased token per request.
    sampling_metadata.logit_bias = _create_logit_bias(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        bias_value=bias_value,
    )
    sampler = Sampler()
    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        # Same token choice as _create_logit_bias: request i biases token i,
        # capped at the last valid token id.
        biased_index = min(batch_idx, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            # NOTE(review): 1e-2 is presumably the uniform value written by
            # _create_fake_logits -- confirm against that helper.
            if biased_index == token_id:
                assert logits_for_req[token_id] == pytest.approx(bias_value +
                                                                 1e-2)
            else:
                assert logits_for_req[token_id] == pytest.approx(1e-2)
...@@ -45,9 +45,11 @@ def _remove_requests( ...@@ -45,9 +45,11 @@ def _remove_requests(
def _construct_expected_sampling_metadata( def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState], req_ids_retained: Set[int], reqs: List[CachedRequestState],
req_id_index_in_input_batch: Dict[str, int], req_ids_retained: Set[int],
device: torch.device) -> SamplingMetadata: req_id_index_in_input_batch: Dict[str, int],
device: torch.device,
) -> SamplingMetadata:
""" """
Constructs and returns the expected SamplingMetadata for this Constructs and returns the expected SamplingMetadata for this
batch. batch.
...@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata( ...@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)] temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)] min_tokens = [0 for _ in range(num_reqs)]
logit_bias = [None] * num_reqs
for req in reqs: for req in reqs:
if req.req_id not in req_ids_retained: if req.req_id not in req_ids_retained:
continue continue
...@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata( ...@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[ presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[ frequency_penalties[index_in_input_batch] = (
index_in_input_batch] = req.sampling_params.frequency_penalty req.sampling_params.frequency_penalty)
repetition_penalties[ repetition_penalties[index_in_input_batch] = (
index_in_input_batch] = req.sampling_params.repetition_penalty req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[ stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
return SamplingMetadata( return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float, device=device), temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False, all_greedy=False,
all_random=True, all_random=True,
top_p=torch.tensor(top_p, dtype=torch.float, device=device), top_p=torch.tensor(top_p, dtype=torch.float, device=device),
...@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata( ...@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
no_top_k=all(x == 0 for x in top_k), no_top_k=all(x == 0 for x in top_k),
generators={}, generators={},
max_num_logprobs=0, max_num_logprobs=0,
prompt_token_ids= make_tensor_with_pad( prompt_token_ids=make_tensor_with_pad(
prompt_token_ids, prompt_token_ids,
pad=VOCAB_SIZE, pad=VOCAB_SIZE,
device=torch.device(device), device=torch.device(device),
dtype=torch.int64, dtype=torch.int64,
), ),
frequency_penalties=torch.tensor( frequency_penalties=torch.tensor(frequency_penalties,
frequency_penalties, dtype=torch.float, dtype=torch.float,
device=device), device=device),
presence_penalties=torch.tensor( presence_penalties=torch.tensor(presence_penalties,
presence_penalties, dtype=torch.float, dtype=torch.float,
device=device), device=device),
repetition_penalties=torch.tensor( repetition_penalties=torch.tensor(repetition_penalties,
repetition_penalties, dtype=torch.float, dtype=torch.float,
device=device), device=device),
output_token_ids=output_token_ids, output_token_ids=output_token_ids,
min_tokens=min_tokens, min_tokens=min_tokens,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
no_penalties=(all(x ==0 for x in presence_penalties) and \ no_penalties=(all(x == 0 for x in presence_penalties)
all(x ==0 for x in frequency_penalties) and \ and all(x == 0 for x in frequency_penalties)
all(x ==1 for x in repetition_penalties)) and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
) )
def _create_sampling_params(): def _create_sampling_params():
return SamplingParams(top_k=np.random.randint(1, 10), return SamplingParams(
top_p=np.random.uniform(0.0, 1.0), top_k=np.random.randint(1, 10),
presence_penalty=np.random.uniform(-2.0, 2.0), top_p=np.random.uniform(0.0, 1.0),
repetition_penalty=np.random.uniform(0.0, 2.0), presence_penalty=np.random.uniform(-2.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0), repetition_penalty=np.random.uniform(0.0, 2.0),
min_tokens=np.random.randint(1, 10), frequency_penalty=np.random.uniform(-2.0, 2.0),
stop_token_ids=[ min_tokens=np.random.randint(1, 10),
np.random.randint(0, VOCAB_SIZE) stop_token_ids=[
for _ in range(np.random.randint(10)) np.random.randint(0, VOCAB_SIZE)
]) for _ in range(np.random.randint(10))
],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)
def _construct_cached_request_state(req_id_suffix: int): def _construct_cached_request_state(req_id_suffix: int):
...@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
np.random.randint(0, VOCAB_SIZE) np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
] ]
return CachedRequestState(req_id=f"req_id_{req_id_suffix}", return CachedRequestState(
prompt_token_ids=prompt_token_ids, req_id=f"req_id_{req_id_suffix}",
prompt=None, prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(), prompt=None,
mm_inputs=[], sampling_params=_create_sampling_params(),
mm_positions=[], mm_inputs=[],
block_ids=[], mm_positions=[],
generator=None, block_ids=[],
num_computed_tokens=len(output_token_ids), generator=None,
output_token_ids=output_token_ids) num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
...@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness. results to ensure correctness.
""" """
input_batch: InputBatch = InputBatch(max_num_reqs=batch_size, input_batch: InputBatch = InputBatch(
max_model_len=1024, max_num_reqs=batch_size,
max_num_blocks_per_req=10, max_model_len=1024,
device=torch.device(device), max_num_blocks_per_req=10,
pin_memory=is_pin_memory_available(), device=torch.device(device),
vocab_size=1024) pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
reqs: List[CachedRequestState] = [] reqs: List[CachedRequestState] = []
req_id_reqs = {} req_id_reqs = {}
req_id_output_token_ids = {} req_id_output_token_ids = {}
...@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.top_p) sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k, assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k) sampling_metadata.top_k)
assert torch.allclose(expected_sampling_metadata.frequency_penalties, assert torch.allclose(
sampling_metadata.frequency_penalties) expected_sampling_metadata.frequency_penalties,
assert torch.allclose(expected_sampling_metadata.presence_penalties, sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties) )
assert torch.allclose(expected_sampling_metadata.repetition_penalties, assert torch.allclose(
sampling_metadata.repetition_penalties) expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids, assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids) sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids == assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids) sampling_metadata.output_token_ids)
assert ( assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) assert expected_sampling_metadata.stop_token_ids == \
assert (expected_sampling_metadata.stop_token_ids == sampling_metadata.stop_token_ids
sampling_metadata.stop_token_ids) assert expected_sampling_metadata.no_penalties == \
assert (expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties
sampling_metadata.no_penalties) assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
...@@ -243,8 +243,10 @@ class SamplingParams( ...@@ -243,8 +243,10 @@ class SamplingParams(
allowed_token_ids: Optional[List[int]] = None, allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams": ) -> "SamplingParams":
if logit_bias is not None: if logit_bias is not None:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias = { logit_bias = {
int(token): bias int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items() for token, bias in logit_bias.items()
} }
......
...@@ -32,3 +32,5 @@ class SamplingMetadata: ...@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids: List[List[int]] output_token_ids: List[List[int]]
min_tokens: List[int] min_tokens: List[int]
stop_token_ids: List[Set[int]] stop_token_ids: List[Set[int]]
logit_bias: List[Optional[Dict[int, float]]]
...@@ -37,6 +37,8 @@ class Sampler(nn.Module): ...@@ -37,6 +37,8 @@ class Sampler(nn.Module):
# Use float32 for the logits. # Use float32 for the logits.
logits = logits.to(torch.float32) logits = logits.to(torch.float32)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties). # Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata) logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature. # Apply temperature.
...@@ -166,3 +168,17 @@ class Sampler(nn.Module): ...@@ -166,3 +168,17 @@ class Sampler(nn.Module):
sampling_metadata.repetition_penalties, sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids) sampling_metadata.output_token_ids)
return logits return logits
def apply_logits_bias(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """
    Adds every request's logit bias to its row of `logits`, in place.

    Args:
        logits: Tensor indexed as `[request_index, token_id]`. It is
            modified in place.
        sampling_metadata: Its `logit_bias` list holds, per request,
            either a `{token_id: bias}` mapping or a falsy value
            (`None` / empty dict) when the request has no bias.

    Returns:
        The same `logits` tensor that was passed in.

    Raises:
        IndexError: If a biased (request, token) position lies outside
            `logits`.
    """
    num_rows, vocab_size = logits.shape[0], logits.shape[1]

    # Flatten all (request, token, bias) triples so they can be applied
    # with a single indexed add instead of one tensor write per token.
    rows = []
    cols = []
    biases = []
    for req_idx, logit_bias in enumerate(sampling_metadata.logit_bias):
        if not logit_bias:
            continue
        if req_idx >= num_rows:
            raise IndexError(
                f"request index {req_idx} is out of bounds for logits "
                f"with {num_rows} rows")
        for token_id, bias in logit_bias.items():
            # Check on the host so an invalid id stays a plain IndexError
            # instead of surfacing as a device-side indexing failure.
            if not -vocab_size <= token_id < vocab_size:
                raise IndexError(
                    f"token id {token_id} is out of bounds for vocab "
                    f"size {vocab_size}")
            rows.append(req_idx)
            cols.append(token_id)
            biases.append(bias)

    if not rows:
        return logits

    # TODO(houseroad): building the index tensors still happens per call;
    # we may even optimize the logit_bias layout to avoid that.
    device = logits.device
    # accumulate=True keeps the additive semantics even when two entries
    # alias the same position (e.g. a negative and a positive id).
    logits.index_put_(
        (
            torch.tensor(rows, dtype=torch.long, device=device),
            torch.tensor(cols, dtype=torch.long, device=device),
        ),
        torch.tensor(biases, dtype=logits.dtype, device=device),
        accumulate=True,
    )
    return logits
...@@ -130,7 +130,7 @@ class InputBatch: ...@@ -130,7 +130,7 @@ class InputBatch:
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.frequency_penalties_cpu = \ self.frequency_penalties_cpu = \
self.frequency_penalties_cpu_tensor.numpy() self.frequency_penalties_cpu_tensor.numpy()
self.frequency_penalties_reqs: Set[str] = set() self.frequency_penalties_reqs: Set[str] = set()
# Presence penalty related data structures # Presence penalty related data structures
...@@ -141,8 +141,8 @@ class InputBatch: ...@@ -141,8 +141,8 @@ class InputBatch:
dtype=torch.float, dtype=torch.float,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.presence_penalties_cpu = \ self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
self.presence_penalties_cpu_tensor.numpy() )
self.presence_penalties_reqs: Set[str] = set() self.presence_penalties_reqs: Set[str] = set()
# Repetition penalty related data structures # Repetition penalty related data structures
...@@ -155,7 +155,7 @@ class InputBatch: ...@@ -155,7 +155,7 @@ class InputBatch:
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.repetition_penalties_cpu = \ self.repetition_penalties_cpu = \
self.repetition_penalties_cpu_tensor.numpy() self.repetition_penalties_cpu_tensor.numpy()
self.repetition_penalties_reqs: Set[str] = set() self.repetition_penalties_reqs: Set[str] = set()
self.min_tokens: List[int] = [0] * max_num_reqs self.min_tokens: List[int] = [0] * max_num_reqs
...@@ -180,6 +180,9 @@ class InputBatch: ...@@ -180,6 +180,9 @@ class InputBatch:
# that are currently in the prefill phase. # that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {} self.num_prompt_logprobs: Dict[str, int] = {}
self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs
def add_request( def add_request(
self, self,
request: "CachedRequestState", request: "CachedRequestState",
...@@ -220,16 +223,16 @@ class InputBatch: ...@@ -220,16 +223,16 @@ class InputBatch:
self.top_k_cpu[req_index] = sampling_params.top_k self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0: if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id) self.top_k_reqs.add(req_id)
self.frequency_penalties_cpu[req_index] = \ self.frequency_penalties_cpu[
sampling_params.frequency_penalty req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0: if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id) self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = \ self.presence_penalties_cpu[
sampling_params.presence_penalty req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0: if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id) self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = \ self.repetition_penalties_cpu[
sampling_params.repetition_penalty req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0: if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id) self.repetition_penalties_reqs.add(req_id)
self.min_tokens[req_index] = sampling_params.min_tokens self.min_tokens[req_index] = sampling_params.min_tokens
...@@ -244,6 +247,8 @@ class InputBatch: ...@@ -244,6 +247,8 @@ class InputBatch:
self.num_logprobs[req_id] = sampling_params.logprobs self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None: if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
...@@ -284,6 +289,7 @@ class InputBatch: ...@@ -284,6 +289,7 @@ class InputBatch:
self.lora_id_to_lora_request.pop(lora_id) self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0 self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
return req_index return req_index
def clear(self) -> None: def clear(self) -> None:
...@@ -302,6 +308,7 @@ class InputBatch: ...@@ -302,6 +308,7 @@ class InputBatch:
self.request_lora_mapping.fill(0) self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear() self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear() self.lora_id_to_request_ids.clear()
self.logit_bias = [None] * self.max_num_reqs
def condense(self, empty_req_indices: List[int]) -> None: def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0: if self.num_reqs == 0:
...@@ -332,8 +339,8 @@ class InputBatch: ...@@ -332,8 +339,8 @@ class InputBatch:
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens] last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens self.num_tokens[empty_index] = num_tokens
self.num_prompt_tokens[empty_index] = \ self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
self.num_prompt_tokens[last_req_index] last_req_index]
self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index] empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index) self.block_table.move_row(last_req_index, empty_index)
...@@ -341,15 +348,15 @@ class InputBatch: ...@@ -341,15 +348,15 @@ class InputBatch:
last_req_index] last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = \ self.frequency_penalties_cpu[
self.frequency_penalties_cpu[last_req_index] empty_index] = self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[empty_index] = \ self.presence_penalties_cpu[
self.presence_penalties_cpu[last_req_index] empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[empty_index] = \ self.repetition_penalties_cpu[
self.repetition_penalties_cpu[last_req_index] empty_index] = self.repetition_penalties_cpu[last_req_index]
self.min_tokens[empty_index] = self.min_tokens[last_req_index] self.min_tokens[empty_index] = self.min_tokens[last_req_index]
self.stop_token_ids[empty_index] = \ self.stop_token_ids[empty_index] = self.stop_token_ids[
self.stop_token_ids[last_req_index] last_req_index]
generator = self.generators.pop(last_req_index, None) generator = self.generators.pop(last_req_index, None)
if generator is not None: if generator is not None:
self.generators[empty_index] = generator self.generators[empty_index] = generator
...@@ -357,6 +364,8 @@ class InputBatch: ...@@ -357,6 +364,8 @@ class InputBatch:
self.request_lora_mapping[empty_index] = self.request_lora_mapping[ self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index] last_req_index]
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
# Decrement last_req_index since it is now empty. # Decrement last_req_index since it is now empty.
last_req_index -= 1 last_req_index -= 1
...@@ -378,13 +387,16 @@ class InputBatch: ...@@ -378,13 +387,16 @@ class InputBatch:
# penalties to be applied during sampling. # penalties to be applied during sampling.
self.frequency_penalties[:self.num_reqs].copy_( self.frequency_penalties[:self.num_reqs].copy_(
self.frequency_penalties_cpu_tensor[:self.num_reqs], self.frequency_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True) non_blocking=True,
)
self.presence_penalties[:self.num_reqs].copy_( self.presence_penalties[:self.num_reqs].copy_(
self.presence_penalties_cpu_tensor[:self.num_reqs], self.presence_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True) non_blocking=True,
)
self.repetition_penalties[:self.num_reqs].copy_( self.repetition_penalties[:self.num_reqs].copy_(
self.repetition_penalties_cpu_tensor[:self.num_reqs], self.repetition_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True) non_blocking=True,
)
# The prompt tokens are used only for applying penalties during # The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when # the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied. # there are requests which need penalties to be applied.
...@@ -421,6 +433,7 @@ class InputBatch: ...@@ -421,6 +433,7 @@ class InputBatch:
min_tokens=self.min_tokens[:self.num_reqs], min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs], stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:self.num_reqs],
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
...@@ -429,10 +442,11 @@ class InputBatch: ...@@ -429,10 +442,11 @@ class InputBatch:
(self.num_reqs, max_prompt_len), (self.num_reqs, max_prompt_len),
device="cpu", device="cpu",
dtype=torch.int64, dtype=torch.int64,
pin_memory=self.pin_memory) pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = ( prompt_token_ids[:] = self.token_ids_cpu[:self.
self.token_ids_cpu[:self.num_reqs, :max_prompt_len]) num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a # Use the value of vocab_size as a pad since we don't have a
# token_id of this value. # token_id of this value.
for i in range(self.num_reqs): for i in range(self.num_reqs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment