Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6224a9f6
Unverified
Commit
6224a9f6
authored
Feb 14, 2025
by
Lu Fang
Committed by
GitHub
Feb 14, 2025
Browse files
Support logit_bias in v1 Sampler (#13079)
parent
085b7b2d
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
200 additions
and
101 deletions
+200
-101
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+59
-12
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+80
-62
vllm/sampling_params.py
vllm/sampling_params.py
+3
-1
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+2
-0
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+16
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+40
-26
No files found.
tests/v1/sample/test_sampler.py
View file @
6224a9f6
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
pytest
...
...
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)
def
_create_logit_bias
(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
List
[
Optional
[
Dict
[
int
,
float
]]]:
res
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
return
res
def
_create_default_sampling_metadata
(
num_output_tokens
:
int
,
batch_size
:
int
,
...
...
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties
=
True
,
min_tokens
=
[],
stop_token_ids
=
[],
logit_bias
=
[
None
]
*
batch_size
,
)
return
fake_sampling_metadata
...
...
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id
not
in
output_tokens
)
assert
(
non_penalized_token_id
in
prompt_tokens
or
\
non_penalized_token_id
in
output_tokens
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
    """
    Test to verify that when logit_bias is set, the bias value is added
    to the logit of the biased token of each request and the logits of
    all other tokens are left unchanged.
    """
    torch.set_default_device(device)
    # Create fake logits where each token is assigned the same
    # logit value.
    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
    sampling_metadata = _create_default_sampling_metadata(
        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
    # Request i biases exactly one token: min(i, VOCAB_SIZE - 1).
    sampling_metadata.logit_bias = _create_logit_bias(
        batch_size=batch_size,
        vocab_size=VOCAB_SIZE,
        bias_value=bias_value,
    )
    sampler = Sampler()
    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
    logits = logits.cpu()
    # Logit value every token is expected to have before biasing
    # (presumably the constant filled in by _create_fake_logits).
    unbiased_logit = 1e-2
    for batch_idx in range(batch_size):
        logits_for_req = logits[batch_idx]
        biased_index = min(batch_idx, VOCAB_SIZE - 1)
        for token_id in range(VOCAB_SIZE):
            if biased_index == token_id:
                assert logits_for_req[token_id] == pytest.approx(
                    bias_value + unbiased_logit)
            else:
                assert logits_for_req[token_id] == pytest.approx(
                    unbiased_logit)
tests/v1/worker/test_gpu_input_batch.py
View file @
6224a9f6
...
...
@@ -45,9 +45,11 @@ def _remove_requests(
def
_construct_expected_sampling_metadata
(
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
device
:
torch
.
device
)
->
SamplingMetadata
:
device
:
torch
.
device
,
)
->
SamplingMetadata
:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
...
...
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
stop_token_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
logit_bias
=
[
None
]
*
num_reqs
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
continue
...
...
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
prompt_token_ids
[
index_in_input_batch
]
=
req
.
prompt_token_ids
presence_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
presence_penalty
frequency_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
frequency_penalty
repetition_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
repetition_penalty
frequency_penalties
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
frequency_penalty
)
repetition_penalties
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
repetition_penalty
)
top_k
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_k
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
stop_token_ids
[
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop_token_ids
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_tokens
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
all_greedy
=
False
,
all_random
=
True
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
...
...
@@ -93,32 +97,34 @@ def _construct_expected_sampling_metadata(
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
,
pad
=
VOCAB_SIZE
,
device
=
torch
.
device
(
device
),
dtype
=
torch
.
int64
,
),
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
torch
.
float
,
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
torch
.
float
,
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
dtype
=
torch
.
float
,
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
output_token_ids
=
output_token_ids
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
\
all
(
x
==
0
for
x
in
frequency_penalties
)
and
\
all
(
x
==
1
for
x
in
repetition_penalties
))
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
all
(
x
==
0
for
x
in
frequency_penalties
)
and
all
(
x
==
1
for
x
in
repetition_penalties
)),
logit_bias
=
logit_bias
,
)
def
_create_sampling_params
():
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
...
...
@@ -127,7 +133,9 @@ def _create_sampling_params():
stop_token_ids
=
[
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
10
))
])
],
logit_bias
=
{
0
:
np
.
random
.
uniform
(
-
3.0
,
3.0
)},
)
def
_construct_cached_request_state
(
req_id_suffix
:
int
):
...
...
@@ -139,7 +147,8 @@ def _construct_cached_request_state(req_id_suffix: int):
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
NUM_OUTPUT_TOKENS
))
]
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
...
...
@@ -148,7 +157,8 @@ def _construct_cached_request_state(req_id_suffix: int):
block_ids
=
[],
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
)
output_token_ids
=
output_token_ids
,
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
)
vocab_size
=
1024
,
)
reqs
:
List
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_output_token_ids
=
{}
...
...
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata
.
top_p
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
prompt_token_ids
)
assert
(
expected_sampling_metadata
.
output_token_ids
==
sampling_metadata
.
output_token_ids
)
assert
(
expected_sampling_metadata
.
min
_tokens
==
sampling_metadata
.
min_tokens
)
assert
(
expected_
sampling_metadata
.
stop_token_ids
==
sampling_metadata
.
stop_token_ids
)
assert
(
expected_
sampling_metadata
.
no_penalties
==
sampling_metadata
.
no_
penalties
)
assert
(
expected_sampling_metadata
.
no_top_
p
==
sampling_metadata
.
no_top_
p
)
assert
(
expected_sampling_metadata
.
no_top_k
==
sampling_metadata
.
no_top_k
)
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
assert
expected_sampling_metadata
.
stop
_token
_id
s
==
\
sampling_metadata
.
stop_token_ids
assert
expected_
sampling_metadata
.
no_penalties
==
\
sampling_metadata
.
no_penalties
assert
expected_
sampling_metadata
.
no_
top_p
==
sampling_metadata
.
no_top_p
assert
expected_sampling_metadata
.
no_top_
k
==
sampling_metadata
.
no_top_
k
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
vllm/sampling_params.py
View file @
6224a9f6
...
...
@@ -243,8 +243,10 @@ class SamplingParams(
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
"SamplingParams"
:
if
logit_bias
is
not
None
:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
logit_bias
=
{
int
(
token
):
bias
int
(
token
):
min
(
100.0
,
max
(
-
100.0
,
bias
))
for
token
,
bias
in
logit_bias
.
items
()
}
...
...
vllm/v1/sample/metadata.py
View file @
6224a9f6
...
...
@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids
:
List
[
List
[
int
]]
min_tokens
:
List
[
int
]
stop_token_ids
:
List
[
Set
[
int
]]
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
vllm/v1/sample/sampler.py
View file @
6224a9f6
...
...
@@ -37,6 +37,8 @@ class Sampler(nn.Module):
# Use float32 for the logits.
logits
=
logits
.
to
(
torch
.
float32
)
# Apply logits bias.
logits
=
self
.
apply_logits_bias
(
logits
,
sampling_metadata
)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits
=
self
.
apply_penalties
(
logits
,
sampling_metadata
)
# Apply temperature.
...
...
@@ -166,3 +168,17 @@ class Sampler(nn.Module):
sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
output_token_ids
)
return
logits
def apply_logits_bias(
    self,
    logits: torch.Tensor,
    sampling_metadata: "SamplingMetadata",
) -> torch.Tensor:
    """Add every request's logit bias to its row of ``logits`` in place.

    Args:
        logits: Tensor indexed as ``logits[request_index, token_id]``.
        sampling_metadata: Object whose ``logit_bias`` attribute holds, per
            request, either a ``{token_id: bias}`` mapping or a falsy value
            (``None`` / empty dict) when the request has no bias.

    Returns:
        The same ``logits`` tensor object, updated in place.

    Raises:
        IndexError: If a request index or token id is out of range for
            ``logits``.
    """
    num_rows, vocab_size = logits.shape[0], logits.shape[1]
    rows = []
    cols = []
    biases = []
    for i, logit_bias in enumerate(sampling_metadata.logit_bias):
        if not logit_bias:
            continue
        if i >= num_rows:
            raise IndexError(
                f"index {i} is out of bounds for dimension 0 "
                f"with size {num_rows}")
        for token_id, bias in logit_bias.items():
            # Validate on the host: an out-of-range index inside a batched
            # device-side update would not surface as a clean IndexError.
            if not -vocab_size <= token_id < vocab_size:
                raise IndexError(
                    f"index {token_id} is out of bounds for dimension 1 "
                    f"with size {vocab_size}")
            rows.append(i)
            cols.append(token_id)
            biases.append(bias)

    if not rows:
        return logits

    # Apply all biases with one batched in-place update instead of one
    # tensor operation per (request, token) pair. accumulate=True keeps the
    # additive semantics even if two keys alias the same element (e.g. -1
    # and vocab_size - 1).
    # TODO(houseroad): the logit_bias layout could still be optimized so
    # these index tensors do not have to be rebuilt on every call.
    device = logits.device
    logits.index_put_(
        (torch.tensor(rows, dtype=torch.long, device=device),
         torch.tensor(cols, dtype=torch.long, device=device)),
        torch.tensor(biases, dtype=logits.dtype, device=device),
        accumulate=True,
    )
    return logits
vllm/v1/worker/gpu_input_batch.py
View file @
6224a9f6
...
...
@@ -141,8 +141,8 @@ class InputBatch:
dtype
=
torch
.
float
,
device
=
"cpu"
,
pin_memory
=
pin_memory
)
self
.
presence_penalties_cpu
=
\
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
self
.
presence_penalties_cpu
=
self
.
presence_penalties_cpu_tensor
.
numpy
(
)
self
.
presence_penalties_reqs
:
Set
[
str
]
=
set
()
# Repetition penalty related data structures
...
...
@@ -180,6 +180,9 @@ class InputBatch:
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
Dict
[
str
,
int
]
=
{}
self
.
logit_bias
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
def
add_request
(
self
,
request
:
"CachedRequestState"
,
...
...
@@ -220,16 +223,16 @@ class InputBatch:
self
.
top_k_cpu
[
req_index
]
=
sampling_params
.
top_k
if
sampling_params
.
top_k
>
0
:
self
.
top_k_reqs
.
add
(
req_id
)
self
.
frequency_penalties_cpu
[
req_index
]
=
\
sampling_params
.
frequency_penalty
self
.
frequency_penalties_cpu
[
req_index
]
=
sampling_params
.
frequency_penalty
if
sampling_params
.
frequency_penalty
!=
0.0
:
self
.
frequency_penalties_reqs
.
add
(
req_id
)
self
.
presence_penalties_cpu
[
req_index
]
=
\
sampling_params
.
presence_penalty
self
.
presence_penalties_cpu
[
req_index
]
=
sampling_params
.
presence_penalty
if
sampling_params
.
presence_penalty
!=
0.0
:
self
.
presence_penalties_reqs
.
add
(
req_id
)
self
.
repetition_penalties_cpu
[
req_index
]
=
\
sampling_params
.
repetition_penalty
self
.
repetition_penalties_cpu
[
req_index
]
=
sampling_params
.
repetition_penalty
if
sampling_params
.
repetition_penalty
!=
1.0
:
self
.
repetition_penalties_reqs
.
add
(
req_id
)
self
.
min_tokens
[
req_index
]
=
sampling_params
.
min_tokens
...
...
@@ -244,6 +247,8 @@ class InputBatch:
self
.
num_logprobs
[
req_id
]
=
sampling_params
.
logprobs
if
sampling_params
.
prompt_logprobs
is
not
None
:
self
.
num_prompt_logprobs
[
req_id
]
=
sampling_params
.
prompt_logprobs
if
sampling_params
.
logit_bias
is
not
None
:
self
.
logit_bias
[
req_index
]
=
sampling_params
.
logit_bias
# Add request lora ID
if
request
.
lora_request
:
...
...
@@ -284,6 +289,7 @@ class InputBatch:
self
.
lora_id_to_lora_request
.
pop
(
lora_id
)
self
.
request_lora_mapping
[
req_index
]
=
0
self
.
logit_bias
[
req_index
]
=
None
return
req_index
def
clear
(
self
)
->
None
:
...
...
@@ -302,6 +308,7 @@ class InputBatch:
self
.
request_lora_mapping
.
fill
(
0
)
self
.
lora_id_to_lora_request
.
clear
()
self
.
lora_id_to_request_ids
.
clear
()
self
.
logit_bias
=
[
None
]
*
self
.
max_num_reqs
def
condense
(
self
,
empty_req_indices
:
List
[
int
])
->
None
:
if
self
.
num_reqs
==
0
:
...
...
@@ -332,8 +339,8 @@ class InputBatch:
self
.
token_ids_cpu
[
empty_index
,
:
num_tokens
]
=
self
.
token_ids_cpu
[
last_req_index
,
:
num_tokens
]
self
.
num_tokens
[
empty_index
]
=
num_tokens
self
.
num_prompt_tokens
[
empty_index
]
=
\
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_prompt_tokens
[
empty_index
]
=
self
.
num_prompt_tokens
[
last_req_index
]
self
.
num_computed_tokens_cpu
[
empty_index
]
=
self
.
num_computed_tokens_cpu
[
last_req_index
]
self
.
block_table
.
move_row
(
last_req_index
,
empty_index
)
...
...
@@ -341,15 +348,15 @@ class InputBatch:
last_req_index
]
self
.
top_p_cpu
[
empty_index
]
=
self
.
top_p_cpu
[
last_req_index
]
self
.
top_k_cpu
[
empty_index
]
=
self
.
top_k_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
\
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
\
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
\
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
frequency_penalties_cpu
[
empty_index
]
=
self
.
frequency_penalties_cpu
[
last_req_index
]
self
.
presence_penalties_cpu
[
empty_index
]
=
self
.
presence_penalties_cpu
[
last_req_index
]
self
.
repetition_penalties_cpu
[
empty_index
]
=
self
.
repetition_penalties_cpu
[
last_req_index
]
self
.
min_tokens
[
empty_index
]
=
self
.
min_tokens
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
\
self
.
stop_token_ids
[
last_req_index
]
self
.
stop_token_ids
[
empty_index
]
=
self
.
stop_token_ids
[
last_req_index
]
generator
=
self
.
generators
.
pop
(
last_req_index
,
None
)
if
generator
is
not
None
:
self
.
generators
[
empty_index
]
=
generator
...
...
@@ -357,6 +364,8 @@ class InputBatch:
self
.
request_lora_mapping
[
empty_index
]
=
self
.
request_lora_mapping
[
last_req_index
]
self
.
logit_bias
[
empty_index
]
=
self
.
logit_bias
[
last_req_index
]
# Decrement last_req_index since it is now empty.
last_req_index
-=
1
...
...
@@ -378,13 +387,16 @@ class InputBatch:
# penalties to be applied during sampling.
self
.
frequency_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
frequency_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
self
.
presence_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
presence_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
self
.
repetition_penalties
[:
self
.
num_reqs
].
copy_
(
self
.
repetition_penalties_cpu_tensor
[:
self
.
num_reqs
],
non_blocking
=
True
)
non_blocking
=
True
,
)
# The prompt tokens are used only for applying penalties during
# the sampling process. Hence copy these tensors only when
# there are requests which need penalties to be applied.
...
...
@@ -421,6 +433,7 @@ class InputBatch:
min_tokens
=
self
.
min_tokens
[:
self
.
num_reqs
],
stop_token_ids
=
self
.
stop_token_ids
[:
self
.
num_reqs
],
no_penalties
=
self
.
no_penalties
,
logit_bias
=
self
.
logit_bias
[:
self
.
num_reqs
],
)
def
_make_prompt_token_ids_tensor
(
self
)
->
torch
.
Tensor
:
...
...
@@ -429,10 +442,11 @@ class InputBatch:
(
self
.
num_reqs
,
max_prompt_len
),
device
=
"cpu"
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
)
pin_memory
=
self
.
pin_memory
,
)
prompt_token_ids
=
prompt_token_ids_cpu_tensor
.
numpy
()
prompt_token_ids
[:]
=
(
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
)
prompt_token_ids
[:]
=
self
.
token_ids_cpu
[:
self
.
num_reqs
,
:
max_prompt_len
]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for
i
in
range
(
self
.
num_reqs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment