Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6224a9f6
Unverified
Commit
6224a9f6
authored
Feb 14, 2025
by
Lu Fang
Committed by
GitHub
Feb 14, 2025
Browse files
Support logit_bias in v1 Sampler (#13079)
parent
085b7b2d
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
200 additions
and
101 deletions
+200
-101
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+59
-12
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+80
-62
vllm/sampling_params.py
vllm/sampling_params.py
+3
-1
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+2
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+16
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+40
-26
No files found.
tests/v1/sample/test_sampler.py
View file @
6224a9f6
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
pytest
...
...
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)
def
_create_logit_bias
(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
List
[
Optional
[
Dict
[
int
,
float
]]]:
res
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
return
res
def
_create_default_sampling_metadata
(
num_output_tokens
:
int
,
batch_size
:
int
,
...
...
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties
=
True
,
min_tokens
=
[],
stop_token_ids
=
[],
logit_bias
=
[
None
]
*
batch_size
,
)
return
fake_sampling_metadata
...
...
@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
batch_indices_for_min_token_penalty
:
List
[
int
]
)
->
Tuple
[
List
[
int
],
List
[
Set
[
int
]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids
:
List
[
Set
[
int
]]
=
[]
min_tokens
:
List
[
int
]
=
[]
...
...
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
batch_size
:
int
,
vocab_size
:
int
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
int
]]]:
"""
Creates an output token list where each token occurs a distinct
Creates an output token list where each token occurs a distinct
number of times.
For each batch, a random subset of token IDs is selected from the
...
...
@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(
Returns:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
...
...
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
def
test_sampler_min_tokens_penalty
(
device
:
str
,
batch_size
:
int
):
"""
Tests that if the number of output tokens is less than
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
...
...
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def
test_sampler_repetition_penalty
(
device
:
str
,
batch_size
:
int
,
repetition_penalty
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
...
...
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id
not
in
output_tokens
)
assert
(
non_penalized_token_id
in
prompt_tokens
or
\
non_penalized_token_id
in
output_tokens
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bias_value"
,
[
-
0.1
,
1.2
])
def
test_sampler_logit_bias
(
device
:
str
,
batch_size
:
int
,
bias_value
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch
.
set_default_device
(
device
)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
sampling_metadata
.
logit_bias
=
_create_logit_bias
(
batch_size
=
batch_size
,
vocab_size
=
VOCAB_SIZE
,
bias_value
=
bias_value
,
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_logits_bias
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logits_for_req
=
logits
[
batch_idx
]
biased_index
=
min
(
batch_idx
,
VOCAB_SIZE
-
1
)
for
token_id
in
range
(
VOCAB_SIZE
):
if
biased_index
==
token_id
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
bias_value
+
1e-2
)
else
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
1e-2
)
tests/v1/worker/test_gpu_input_batch.py
View file @
6224a9f6
...
...
@@ -45,9 +45,11 @@ def _remove_requests(
def
_construct_expected_sampling_metadata
(
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
device
:
torch
.
device
)
->
SamplingMetadata
:
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
device
:
torch
.
device
,
)
->
SamplingMetadata
:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
...
...
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
logit_bias
=
[
None
]
*
num_reqs
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
continue
...
...
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids
[
index_in_input_batch
]
=
req
.
prompt_token_ids
presence_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
presence_penalty
frequency_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
frequency_penalty
repetition_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
repetition_penalty
frequency_penalties
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
frequency_penalty
)
repetition_penalties
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
repetition_penalty
)
top_k
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_k
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
stop_token_ids
[
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop_token_ids
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_tokens
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
all_greedy
=
False
,
all_random
=
True
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
...
...
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
,
pad
=
VOCAB_SIZE
,
device
=
torch
.
device
(
device
),
dtype
=
torch
.
int64
,
),
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
output_token_ids
=
output_token_ids
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
\
all
(
x
==
0
for
x
in
frequency_penalties
)
and
\
all
(
x
==
1
for
x
in
repetition_penalties
))
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
all
(
x
==
0
for
x
in
frequency_penalties
)
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
logit_bias
=
logit_bias
,
)
def
_create_sampling_params
():
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
frequency_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
min_tokens
=
np
.
random
.
randint
(
1
,
10
),
stop_token_ids
=
[
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
10
))
])
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
frequency_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
min_tokens
=
np
.
random
.
randint
(
1
,
10
),
stop_token_ids
=
[
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
10
))
],
logit_bias
=
{
0
:
np
.
random
.
uniform
(
-
3.0
,
3.0
)},
)
def
_construct_cached_request_state
(
req_id_suffix
:
int
):
...
...
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
NUM_OUTPUT_TOKENS
))
]
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[],
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
)
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[],
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
)
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
)
reqs
:
List
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_output_token_ids
=
{}
...
...
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata
.
top_p
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
prompt_token_ids
)
assert
(
expected_sampling_metadata
.
output_token_ids
==
sampling_metadata
.
output_token_ids
)
assert
(
expected_sampling_metadata
.
min
_tokens
==
sampling_metadata
.
min_tokens
)
assert
(
expected_
sampling_metadata
.
stop_token_ids
==
sampling_metadata
.
stop_token_ids
)
assert
(
expected_
sampling_metadata
.
no_penalties
==
sampling_metadata
.
no_
penalties
)
assert
(
expected_sampling_metadata
.
no_top_
p
==
sampling_metadata
.
no_top_
p
)
assert
(
expected_sampling_metadata
.
no_top_k
==
sampling_metadata
.
no_top_k
)
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
assert
expected_sampling_metadata
.
stop
_token
_id
s
==
\
sampling_metadata
.
stop_token_ids
assert
expected_
sampling_metadata
.
no_penalties
==
\
sampling_metadata
.
no_penalties
assert
expected_
sampling_metadata
.
no_
top_p
==
sampling_metadata
.
no_top_p
assert
expected_sampling_metadata
.
no_top_
k
==
sampling_metadata
.
no_top_
k
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
vllm/sampling_params.py
View file @
6224a9f6
...
...
@@ -243,8 +243,10 @@ class SamplingParams(
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
"SamplingParams"
:
if
logit_bias
is
not
None
:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
=
{
int
(
token
):
bias
int
(
token
):
min
(
100.0
,
max
(
-
100.0
,
bias
))
for
token
,
bias
in
logit_bias
.
items
()
}
...
...
vllm/v1/sample/metadata.py
View file @
6224a9f6
...
...
@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids
:
List
[
List
[
int
]]
min_tokens
:
List
[
int
]
stop_token_ids
:
List
[
Set
[
int
]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
vllm/v1/sample/sampler.py
View file @
6224a9f6
...
...
@@ -37,6 +37,8 @@ class Sampler(nn.Module):
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
# Apply logits bias.
logits
=
self
.
apply_logits_bias
(
logits
,
sampling_metadata
)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Apply temperature.
...
...
@@ -166,3 +168,17 @@ class Sampler(nn.Module):
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
return
logits
def
apply_logits_bias
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
for
i
,
logit_bias
in
enumerate
(
sampling_metadata
.
logit_bias
):
if
logit_bias
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
i
,
token_id
]
+=
bias
return
logits
vllm/v1/worker/gpu_input_batch.py
View file @
6224a9f6
...
...
@@ -130,7 +130,7 @@ class InputBatch:
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
frequency_penalties_cpu
=
\
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_cpu_tensor
.
numpy
()
self
.
frequency_penalties_reqs
:
Set
[
str
]
=
set
()
# Presence penalty related data structures
...
...
@@ -141,8 +141,8 @@ class InputBatch:
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
\
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
self
.
presence_penalties_reqs
:
Set
[
str
]
=
set
()
# Repetition penalty related data structures
...
...
@@ -155,7 +155,7 @@ class InputBatch:
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
repetition_penalties_cpu
=
\
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_cpu_tensor
.
numpy
()
self
.
repetition_penalties_reqs
:
Set
[
str
]
=
set
()
self
.
min_tokens
:
List
[
int
]
=
[
0
]
*
max_num_reqs
...
...
@@ -180,6 +180,9 @@ class InputBatch:
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
Dict
[
str
,
int
]
=
{}
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
def
add_request
(
self
,
request
:
"CachedRequestState"
,
...
...
@@ -220,16 +223,16 @@ class InputBatch:
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
if
sampling_params
.
top_k
>
0
:
self
.
top_k_reqs
.
add
(
req_id
)
self
.
frequency_penalties_cpu
[
req_index
]
=
\
sampling_params
.
frequency_penalty
self
.
frequency_penalties_cpu
[
req_index
]
=
sampling_params
.
frequency_penalty
if
sampling_params
.
frequency_penalty
!=
0.0
:
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_cpu
[
req_index
]
=
\
sampling_params
.
presence_penalty
self
.
presence_penalties_cpu
[
req_index
]
=
sampling_params
.
presence_penalty
if
sampling_params
.
presence_penalty
!=
0.0
:
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_cpu
[
req_index
]
=
\
sampling_params
.
repetition_penalty
self
.
repetition_penalties_cpu
[
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
...
...
@@ -244,6 +247,8 @@ class InputBatch:
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
# Add request lora ID
if
request
.
lora_request
:
...
...
@@ -284,6 +289,7 @@ class InputBatch:
self
.
lora_id_to_lora_request
.
pop
(
lora_id
)
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
logit_bias
[
req_index
]
=
None
return
req_index
def
clear
(
self
)
->
None
:
...
...
@@ -302,6 +308,7 @@ class InputBatch:
self
.
request_lora_mapping
.
fill
(
0
)
self
.
lora_id_to_lora_request
.
clear
()
self
.
lora_id_to_request_ids
.
clear
()
self
.
logit_bias
=
[
None
]
*
self
.
max_num_reqs
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
if
self
.
num_reqs
==
0
:
...
...
@@ -332,8 +339,8 @@ class InputBatch:
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
last_req_index
,
:
num_tokens
]
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_prompt_tokens
[
empty_index
]
=
\
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
...
...
@@ -341,15 +348,15 @@ class InputBatch:
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
\
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
\
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
\
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
\
self
.
stop_token_ids
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
self
.
stop_token_ids
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
...
...
@@ -357,6 +364,8 @@ class InputBatch:
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
self
.
logit_bias
[
empty_index
]
=
self
.
logit_bias
[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
...
...
@@ -378,13 +387,16 @@ class InputBatch:
# penalties to be applied during sampling.
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
...
...
@@ -421,6 +433,7 @@ class InputBatch:
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
self
.
num_reqs
],
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
...
...
@@ -429,10 +442,11 @@ class InputBatch:
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
,
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
(
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
)
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment