Unverified Commit 6224a9f6 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

Support logit_bias in v1 Sampler (#13079)

parent 085b7b2d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import List, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
import numpy as np import numpy as np
import pytest import pytest
...@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor( ...@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
) )
def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res
def _create_default_sampling_metadata( def _create_default_sampling_metadata(
num_output_tokens: int, num_output_tokens: int,
batch_size: int, batch_size: int,
...@@ -80,6 +92,7 @@ def _create_default_sampling_metadata( ...@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties=True, no_penalties=True,
min_tokens=[], min_tokens=[],
stop_token_ids=[], stop_token_ids=[],
logit_bias=[None] * batch_size,
) )
return fake_sampling_metadata return fake_sampling_metadata
...@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int, ...@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id not in output_tokens) penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \ assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens) non_penalized_token_id in output_tokens)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
    """
    Test to verify that `Sampler.apply_logits_bias` adds each request's
    logit bias to the biased token only, and leaves the logits of every
    other token unchanged.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Override the default (all-None) biases with one biased token per
    # request.
    sampling_metadata.logit_bias = _create_logit_bias(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        bias_value=bias_value,
    )
    sampler = Sampler()
    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
    logits = logits.cpu()
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        # Mirrors `_create_logit_bias`: request `batch_idx` biases token
        # `batch_idx`, clamped to the last token id of the vocabulary.
        biased_index = min(batch_idx, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            # NOTE(review): 1e-2 is presumably the uniform value written by
            # `_create_fake_logits` - confirm against that helper.
            if biased_index == token_id:
                assert logits_for_req[token_id] == pytest.approx(bias_value +
                                                                 1e-2)
            else:
                assert logits_for_req[token_id] == pytest.approx(1e-2)
...@@ -45,9 +45,11 @@ def _remove_requests( ...@@ -45,9 +45,11 @@ def _remove_requests(
def _construct_expected_sampling_metadata( def _construct_expected_sampling_metadata(
reqs: List[CachedRequestState], req_ids_retained: Set[int], reqs: List[CachedRequestState],
req_ids_retained: Set[int],
req_id_index_in_input_batch: Dict[str, int], req_id_index_in_input_batch: Dict[str, int],
device: torch.device) -> SamplingMetadata: device: torch.device,
) -> SamplingMetadata:
""" """
Constructs and returns the expected SamplingMetadata for this Constructs and returns the expected SamplingMetadata for this
batch. batch.
...@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata( ...@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature = [0.0 for _ in range(num_reqs)] temperature = [0.0 for _ in range(num_reqs)]
stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)] stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
min_tokens = [0 for _ in range(num_reqs)] min_tokens = [0 for _ in range(num_reqs)]
logit_bias = [None] * num_reqs
for req in reqs: for req in reqs:
if req.req_id not in req_ids_retained: if req.req_id not in req_ids_retained:
continue continue
...@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata( ...@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[ presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[ frequency_penalties[index_in_input_batch] = (
index_in_input_batch] = req.sampling_params.frequency_penalty req.sampling_params.frequency_penalty)
repetition_penalties[ repetition_penalties[index_in_input_batch] = (
index_in_input_batch] = req.sampling_params.repetition_penalty req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature temperature[index_in_input_batch] = req.sampling_params.temperature
stop_token_ids[ stop_token_ids[
index_in_input_batch] = req.sampling_params.all_stop_token_ids index_in_input_batch] = req.sampling_params.all_stop_token_ids
min_tokens[index_in_input_batch] = req.sampling_params.min_tokens min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
return SamplingMetadata( return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float, device=device), temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False, all_greedy=False,
all_random=True, all_random=True,
top_p=torch.tensor(top_p, dtype=torch.float, device=device), top_p=torch.tensor(top_p, dtype=torch.float, device=device),
...@@ -93,32 +97,34 @@ def _construct_expected_sampling_metadata( ...@@ -93,32 +97,34 @@ def _construct_expected_sampling_metadata(
no_top_k=all(x == 0 for x in top_k), no_top_k=all(x == 0 for x in top_k),
generators={}, generators={},
max_num_logprobs=0, max_num_logprobs=0,
prompt_token_ids= make_tensor_with_pad( prompt_token_ids=make_tensor_with_pad(
prompt_token_ids, prompt_token_ids,
pad=VOCAB_SIZE, pad=VOCAB_SIZE,
device=torch.device(device), device=torch.device(device),
dtype=torch.int64, dtype=torch.int64,
), ),
frequency_penalties=torch.tensor( frequency_penalties=torch.tensor(frequency_penalties,
frequency_penalties, dtype=torch.float, dtype=torch.float,
device=device), device=device),
presence_penalties=torch.tensor( presence_penalties=torch.tensor(presence_penalties,
presence_penalties, dtype=torch.float, dtype=torch.float,
device=device), device=device),
repetition_penalties=torch.tensor( repetition_penalties=torch.tensor(repetition_penalties,
repetition_penalties, dtype=torch.float, dtype=torch.float,
device=device), device=device),
output_token_ids=output_token_ids, output_token_ids=output_token_ids,
min_tokens=min_tokens, min_tokens=min_tokens,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
no_penalties=(all(x ==0 for x in presence_penalties) and \ no_penalties=(all(x == 0 for x in presence_penalties)
all(x ==0 for x in frequency_penalties) and \ and all(x == 0 for x in frequency_penalties)
all(x ==1 for x in repetition_penalties)) and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
) )
def _create_sampling_params(): def _create_sampling_params():
return SamplingParams(top_k=np.random.randint(1, 10), return SamplingParams(
top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0), top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0), presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0), repetition_penalty=np.random.uniform(0.0, 2.0),
...@@ -127,7 +133,9 @@ def _create_sampling_params(): ...@@ -127,7 +133,9 @@ def _create_sampling_params():
stop_token_ids=[ stop_token_ids=[
np.random.randint(0, VOCAB_SIZE) np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10)) for _ in range(np.random.randint(10))
]) ],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)
def _construct_cached_request_state(req_id_suffix: int): def _construct_cached_request_state(req_id_suffix: int):
...@@ -139,7 +147,8 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -139,7 +147,8 @@ def _construct_cached_request_state(req_id_suffix: int):
np.random.randint(0, VOCAB_SIZE) np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
] ]
return CachedRequestState(req_id=f"req_id_{req_id_suffix}", return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt=None, prompt=None,
sampling_params=_create_sampling_params(), sampling_params=_create_sampling_params(),
...@@ -148,7 +157,8 @@ def _construct_cached_request_state(req_id_suffix: int): ...@@ -148,7 +157,8 @@ def _construct_cached_request_state(req_id_suffix: int):
block_ids=[], block_ids=[],
generator=None, generator=None,
num_computed_tokens=len(output_token_ids), num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids) output_token_ids=output_token_ids,
)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
...@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness. results to ensure correctness.
""" """
input_batch: InputBatch = InputBatch(max_num_reqs=batch_size, input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024, max_model_len=1024,
max_num_blocks_per_req=10, max_num_blocks_per_req=10,
device=torch.device(device), device=torch.device(device),
pin_memory=is_pin_memory_available(), pin_memory=is_pin_memory_available(),
vocab_size=1024) vocab_size=1024,
)
reqs: List[CachedRequestState] = [] reqs: List[CachedRequestState] = []
req_id_reqs = {} req_id_reqs = {}
req_id_output_token_ids = {} req_id_output_token_ids = {}
...@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): ...@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.top_p) sampling_metadata.top_p)
assert torch.allclose(expected_sampling_metadata.top_k, assert torch.allclose(expected_sampling_metadata.top_k,
sampling_metadata.top_k) sampling_metadata.top_k)
assert torch.allclose(expected_sampling_metadata.frequency_penalties, assert torch.allclose(
sampling_metadata.frequency_penalties) expected_sampling_metadata.frequency_penalties,
assert torch.allclose(expected_sampling_metadata.presence_penalties, sampling_metadata.frequency_penalties,
sampling_metadata.presence_penalties) )
assert torch.allclose(expected_sampling_metadata.repetition_penalties, assert torch.allclose(
sampling_metadata.repetition_penalties) expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids, assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids) sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids == assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids) sampling_metadata.output_token_ids)
assert ( assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens) assert expected_sampling_metadata.stop_token_ids == \
assert (expected_sampling_metadata.stop_token_ids == sampling_metadata.stop_token_ids
sampling_metadata.stop_token_ids) assert expected_sampling_metadata.no_penalties == \
assert (expected_sampling_metadata.no_penalties == sampling_metadata.no_penalties
sampling_metadata.no_penalties) assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p) assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k) assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
...@@ -243,8 +243,10 @@ class SamplingParams( ...@@ -243,8 +243,10 @@ class SamplingParams(
allowed_token_ids: Optional[List[int]] = None, allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams": ) -> "SamplingParams":
if logit_bias is not None: if logit_bias is not None:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias = { logit_bias = {
int(token): bias int(token): min(100.0, max(-100.0, bias))
for token, bias in logit_bias.items() for token, bias in logit_bias.items()
} }
......
...@@ -32,3 +32,5 @@ class SamplingMetadata: ...@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids: List[List[int]] output_token_ids: List[List[int]]
min_tokens: List[int] min_tokens: List[int]
stop_token_ids: List[Set[int]] stop_token_ids: List[Set[int]]
logit_bias: List[Optional[Dict[int, float]]]
...@@ -37,6 +37,8 @@ class Sampler(nn.Module): ...@@ -37,6 +37,8 @@ class Sampler(nn.Module):
# Use float32 for the logits. # Use float32 for the logits.
logits = logits.to(torch.float32) logits = logits.to(torch.float32)
# Apply logits bias.
logits = self.apply_logits_bias(logits, sampling_metadata)
# Apply penalties (e.g., min_tokens, freq_penalties). # Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata) logits = self.apply_penalties(logits, sampling_metadata)
# Apply temperature. # Apply temperature.
...@@ -166,3 +168,17 @@ class Sampler(nn.Module): ...@@ -166,3 +168,17 @@ class Sampler(nn.Module):
sampling_metadata.repetition_penalties, sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids) sampling_metadata.output_token_ids)
return logits return logits
def apply_logits_bias(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Add each request's per-token logit bias to ``logits`` in place.

    Args:
        logits: 2-D tensor indexed as ``[request_index, token_id]``.
        sampling_metadata: Object whose ``logit_bias`` attribute holds, for
            each request, either ``None``/an empty dict (no bias) or a
            ``{token_id: bias}`` mapping.

    Returns:
        The same ``logits`` tensor object, after the in-place update.

    Raises:
        IndexError: If a token id falls outside the last dimension of
            ``logits``.
    """
    vocab_size = logits.shape[-1]
    # Collect every (request, token, bias) triple on the host first so the
    # tensor is updated by ONE indexed op instead of one tiny op per biased
    # token.
    rows = []
    cols = []
    biases = []
    for req_idx, logit_bias in enumerate(sampling_metadata.logit_bias):
        if not logit_bias:
            continue
        for token_id, bias in logit_bias.items():
            # Validate on the host so a bad token id is reported as an
            # IndexError here rather than inside the batched tensor op.
            # Negative ids wrap around, as with plain tensor indexing.
            if not -vocab_size <= token_id < vocab_size:
                raise IndexError(
                    f"logit_bias token id {token_id} is out of range for "
                    f"vocabulary size {vocab_size}")
            rows.append(req_idx)
            cols.append(token_id)
            biases.append(bias)

    if rows:
        device = logits.device
        # accumulate=True adds to the existing logits (rather than
        # overwriting them) and sums correctly if two ids alias one column.
        logits.index_put_(
            (torch.tensor(rows, dtype=torch.long, device=device),
             torch.tensor(cols, dtype=torch.long, device=device)),
            torch.tensor(biases, dtype=logits.dtype, device=device),
            accumulate=True,
        )
    return logits
...@@ -141,8 +141,8 @@ class InputBatch: ...@@ -141,8 +141,8 @@ class InputBatch:
dtype=torch.float, dtype=torch.float,
device="cpu", device="cpu",
pin_memory=pin_memory) pin_memory=pin_memory)
self.presence_penalties_cpu = \ self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
self.presence_penalties_cpu_tensor.numpy() )
self.presence_penalties_reqs: Set[str] = set() self.presence_penalties_reqs: Set[str] = set()
# Repetition penalty related data structures # Repetition penalty related data structures
...@@ -180,6 +180,9 @@ class InputBatch: ...@@ -180,6 +180,9 @@ class InputBatch:
# that are currently in the prefill phase. # that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {} self.num_prompt_logprobs: Dict[str, int] = {}
self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs
def add_request( def add_request(
self, self,
request: "CachedRequestState", request: "CachedRequestState",
...@@ -220,16 +223,16 @@ class InputBatch: ...@@ -220,16 +223,16 @@ class InputBatch:
self.top_k_cpu[req_index] = sampling_params.top_k self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0: if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id) self.top_k_reqs.add(req_id)
self.frequency_penalties_cpu[req_index] = \ self.frequency_penalties_cpu[
sampling_params.frequency_penalty req_index] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0: if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id) self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[req_index] = \ self.presence_penalties_cpu[
sampling_params.presence_penalty req_index] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0: if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id) self.presence_penalties_reqs.add(req_id)
self.repetition_penalties_cpu[req_index] = \ self.repetition_penalties_cpu[
sampling_params.repetition_penalty req_index] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0: if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id) self.repetition_penalties_reqs.add(req_id)
self.min_tokens[req_index] = sampling_params.min_tokens self.min_tokens[req_index] = sampling_params.min_tokens
...@@ -244,6 +247,8 @@ class InputBatch: ...@@ -244,6 +247,8 @@ class InputBatch:
self.num_logprobs[req_id] = sampling_params.logprobs self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None: if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = sampling_params.logit_bias
# Add request lora ID # Add request lora ID
if request.lora_request: if request.lora_request:
...@@ -284,6 +289,7 @@ class InputBatch: ...@@ -284,6 +289,7 @@ class InputBatch:
self.lora_id_to_lora_request.pop(lora_id) self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0 self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
return req_index return req_index
def clear(self) -> None: def clear(self) -> None:
...@@ -302,6 +308,7 @@ class InputBatch: ...@@ -302,6 +308,7 @@ class InputBatch:
self.request_lora_mapping.fill(0) self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear() self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear() self.lora_id_to_request_ids.clear()
self.logit_bias = [None] * self.max_num_reqs
def condense(self, empty_req_indices: List[int]) -> None: def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0: if self.num_reqs == 0:
...@@ -332,8 +339,8 @@ class InputBatch: ...@@ -332,8 +339,8 @@ class InputBatch:
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens] last_req_index, :num_tokens]
self.num_tokens[empty_index] = num_tokens self.num_tokens[empty_index] = num_tokens
self.num_prompt_tokens[empty_index] = \ self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
self.num_prompt_tokens[last_req_index] last_req_index]
self.num_computed_tokens_cpu[ self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index] empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index) self.block_table.move_row(last_req_index, empty_index)
...@@ -341,15 +348,15 @@ class InputBatch: ...@@ -341,15 +348,15 @@ class InputBatch:
last_req_index] last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
self.frequency_penalties_cpu[empty_index] = \ self.frequency_penalties_cpu[
self.frequency_penalties_cpu[last_req_index] empty_index] = self.frequency_penalties_cpu[last_req_index]
self.presence_penalties_cpu[empty_index] = \ self.presence_penalties_cpu[
self.presence_penalties_cpu[last_req_index] empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[empty_index] = \ self.repetition_penalties_cpu[
self.repetition_penalties_cpu[last_req_index] empty_index] = self.repetition_penalties_cpu[last_req_index]
self.min_tokens[empty_index] = self.min_tokens[last_req_index] self.min_tokens[empty_index] = self.min_tokens[last_req_index]
self.stop_token_ids[empty_index] = \ self.stop_token_ids[empty_index] = self.stop_token_ids[
self.stop_token_ids[last_req_index] last_req_index]
generator = self.generators.pop(last_req_index, None) generator = self.generators.pop(last_req_index, None)
if generator is not None: if generator is not None:
self.generators[empty_index] = generator self.generators[empty_index] = generator
...@@ -357,6 +364,8 @@ class InputBatch: ...@@ -357,6 +364,8 @@ class InputBatch:
self.request_lora_mapping[empty_index] = self.request_lora_mapping[ self.request_lora_mapping[empty_index] = self.request_lora_mapping[
last_req_index] last_req_index]
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
# Decrement last_req_index since it is now empty. # Decrement last_req_index since it is now empty.
last_req_index -= 1 last_req_index -= 1
...@@ -378,13 +387,16 @@ class InputBatch: ...@@ -378,13 +387,16 @@ class InputBatch:
# penalties to be applied during sampling. # penalties to be applied during sampling.
self.frequency_penalties[:self.num_reqs].copy_( self.frequency_penalties[:self.num_reqs].copy_(
self.frequency_penalties_cpu_tensor[:self.num_reqs], self.frequency_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True) non_blocking=True,
)
self.presence_penalties[:self.num_reqs].copy_( self.presence_penalties[:self.num_reqs].copy_(
self.presence_penalties_cpu_tensor[:self.num_reqs], self.presence_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True) non_blocking=True,
)
self.repetition_penalties[:self.num_reqs].copy_( self.repetition_penalties[:self.num_reqs].copy_(
self.repetition_penalties_cpu_tensor[:self.num_reqs], self.repetition_penalties_cpu_tensor[:self.num_reqs],
non_blocking=True) non_blocking=True,
)
# The prompt tokens are used only for applying penalties during # The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when # the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied. # there are requests which need penalties to be applied.
...@@ -421,6 +433,7 @@ class InputBatch: ...@@ -421,6 +433,7 @@ class InputBatch:
min_tokens=self.min_tokens[:self.num_reqs], min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs], stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
logit_bias=self.logit_bias[:self.num_reqs],
) )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor: def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
...@@ -429,10 +442,11 @@ class InputBatch: ...@@ -429,10 +442,11 @@ class InputBatch:
(self.num_reqs, max_prompt_len), (self.num_reqs, max_prompt_len),
device="cpu", device="cpu",
dtype=torch.int64, dtype=torch.int64,
pin_memory=self.pin_memory) pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = ( prompt_token_ids[:] = self.token_ids_cpu[:self.
self.token_ids_cpu[:self.num_reqs, :max_prompt_len]) num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a # Use the value of vocab_size as a pad since we don't have a
# token_id of this value. # token_id of this value.
for i in range(self.num_reqs): for i in range(self.num_reqs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment