Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cbc2ef55
Unverified
Commit
cbc2ef55
authored
Oct 10, 2024
by
youkaichao
Committed by
GitHub
Oct 10, 2024
Browse files
[misc] hide best_of from engine (#9261)
Co-authored-by:
Brendan Wong
<
bjwpokemon@gmail.com
>
parent
94bf9ae4
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
46 additions
and
73 deletions
+46
-73
tests/entrypoints/openai/test_metrics.py
tests/entrypoints/openai/test_metrics.py
+0
-4
tests/metrics/test_metrics.py
tests/metrics/test_metrics.py
+0
-1
tests/tracing/test_tracing.py
tests/tracing/test_tracing.py
+0
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-9
vllm/engine/metrics.py
vllm/engine/metrics.py
+0
-8
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+0
-1
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+1
-1
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+8
-9
vllm/outputs.py
vllm/outputs.py
+1
-1
vllm/sampling_params.py
vllm/sampling_params.py
+17
-16
vllm/sequence.py
vllm/sequence.py
+5
-5
vllm/tracing.py
vllm/tracing.py
+0
-1
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+11
-12
No files found.
tests/entrypoints/openai/test_metrics.py
View file @
cbc2ef55
...
...
@@ -70,7 +70,6 @@ EXPECTED_VALUES = {
[(
"_sum"
,
_NUM_REQUESTS
*
_NUM_GENERATION_TOKENS_PER_REQUEST
),
(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_params_n"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:request_params_best_of"
:
[(
"_count"
,
_NUM_REQUESTS
)],
"vllm:prompt_tokens"
:
[(
"_total"
,
_NUM_REQUESTS
*
_NUM_PROMPT_TOKENS_PER_REQUEST
)],
"vllm:generation_tokens"
:
...
...
@@ -151,9 +150,6 @@ EXPECTED_METRICS = [
"vllm:request_params_n_sum"
,
"vllm:request_params_n_bucket"
,
"vllm:request_params_n_count"
,
"vllm:request_params_best_of_sum"
,
"vllm:request_params_best_of_bucket"
,
"vllm:request_params_best_of_count"
,
"vllm:num_preemptions_total"
,
"vllm:prompt_tokens_total"
,
"vllm:generation_tokens_total"
,
...
...
tests/metrics/test_metrics.py
View file @
cbc2ef55
...
...
@@ -326,7 +326,6 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
"vllm:e2e_request_latency_seconds"
,
"vllm:request_prompt_tokens"
,
"vllm:request_generation_tokens"
,
"vllm:request_params_best_of"
,
"vllm:request_params_n"
,
]
for
metric_name
in
request_histogram_metrics
:
...
...
tests/tracing/test_tracing.py
View file @
cbc2ef55
...
...
@@ -98,8 +98,6 @@ def test_traces(trace_service):
SpanAttributes
.
LLM_REQUEST_TOP_P
)
==
sampling_params
.
top_p
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_MAX_TOKENS
)
==
sampling_params
.
max_tokens
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_BEST_OF
)
==
sampling_params
.
best_of
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_N
)
==
sampling_params
.
n
assert
attributes
.
get
(
SpanAttributes
.
LLM_USAGE_PROMPT_TOKENS
)
==
len
(
outputs
[
0
].
prompt_token_ids
)
...
...
@@ -155,8 +153,6 @@ def test_traces_with_detailed_steps(trace_service):
SpanAttributes
.
LLM_REQUEST_TOP_P
)
==
sampling_params
.
top_p
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_MAX_TOKENS
)
==
sampling_params
.
max_tokens
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_BEST_OF
)
==
sampling_params
.
best_of
assert
attributes
.
get
(
SpanAttributes
.
LLM_REQUEST_N
)
==
sampling_params
.
n
assert
attributes
.
get
(
SpanAttributes
.
LLM_USAGE_PROMPT_TOKENS
)
==
len
(
outputs
[
0
].
prompt_token_ids
)
...
...
vllm/core/scheduler.py
View file @
cbc2ef55
...
...
@@ -1205,7 +1205,7 @@ class Scheduler:
# async_output_proc is allowed only when we have a single sequence
# in the sequence group
no_single_seq
=
seq_group
.
sampling_params
is
None
or
(
seq_group
.
sampling_params
.
best_of
==
1
)
seq_group
.
sampling_params
.
n
==
1
)
return
no_single_seq
def
schedule
(
...
...
vllm/engine/llm_engine.py
View file @
cbc2ef55
...
...
@@ -767,7 +767,7 @@ class LLMEngine:
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `
best_of
` number of :class:`~vllm.Sequence` objects.
- Create `
n
` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
...
...
@@ -1242,8 +1242,7 @@ class LLMEngine:
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)"
)
" (i.e sampling_params.n == 1)"
)
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
...
...
@@ -1612,7 +1611,6 @@ class LLMEngine:
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
=
[]
num_generation_tokens_requests
:
List
[
int
]
=
[]
best_of_requests
:
List
[
int
]
=
[]
n_requests
:
List
[
int
]
=
[]
finished_reason_requests
:
List
[
str
]
=
[]
...
...
@@ -1683,8 +1681,6 @@ class LLMEngine:
for
seq
in
seq_group
.
get_finished_seqs
()
])
if
seq_group
.
sampling_params
is
not
None
:
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
finished_reason_requests
.
extend
([
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
...
...
@@ -1737,7 +1733,6 @@ class LLMEngine:
# Metadata
num_prompt_tokens_requests
=
num_prompt_tokens_requests
,
num_generation_tokens_requests
=
num_generation_tokens_requests
,
best_of_requests
=
best_of_requests
,
n_requests
=
n_requests
,
finished_reason_requests
=
finished_reason_requests
,
)
...
...
@@ -1824,8 +1819,6 @@ class LLMEngine:
seq_group
.
sampling_params
.
top_p
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_REQUEST_MAX_TOKENS
,
seq_group
.
sampling_params
.
max_tokens
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_REQUEST_BEST_OF
,
seq_group
.
sampling_params
.
best_of
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_REQUEST_N
,
seq_group
.
sampling_params
.
n
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_USAGE_NUM_SEQUENCES
,
...
...
vllm/engine/metrics.py
View file @
cbc2ef55
...
...
@@ -134,12 +134,6 @@ class Metrics:
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
self
.
_histogram_cls
(
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
histogram_n_request
=
self
.
_histogram_cls
(
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
...
...
@@ -473,8 +467,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
metrics
.
histogram_num_generation_tokens_request
,
stats
.
num_generation_tokens_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_n_request
,
stats
.
n_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_best_of_request
,
stats
.
best_of_requests
)
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
generation_throughput
:
float
)
->
None
:
...
...
vllm/engine/metrics_types.py
View file @
cbc2ef55
...
...
@@ -49,7 +49,6 @@ class Stats:
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
num_generation_tokens_requests
:
List
[
int
]
best_of_requests
:
List
[
int
]
n_requests
:
List
[
int
]
finished_reason_requests
:
List
[
str
]
...
...
vllm/engine/output_processor/single_step.py
View file @
cbc2ef55
...
...
@@ -112,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs
:
SequenceGroupOutput
,
is_async
:
bool
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
best_of
==
1
:
if
sampling_params
.
n
==
1
:
# only have one output sample
sample
=
outputs
.
samples
[
0
]
# only have one sequence
...
...
vllm/model_executor/layers/sampler.py
View file @
cbc2ef55
...
...
@@ -508,7 +508,7 @@ def _random_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum
best_of
value of the prompt phase requests.
# Find the maximum
n
value of the prompt phase requests.
random_samples
=
random_samples
.
cpu
()
sample_idx
=
0
results
:
SampleResultType
=
[]
...
...
@@ -523,9 +523,9 @@ def _random_sample(
num_parent_seqs
=
len
(
seq_ids
)
if
is_prompt
:
# Prompt phase.
parent_ids
=
[
0
]
*
sampling_params
.
best_of
parent_ids
=
[
0
]
*
sampling_params
.
n
next_token_ids
=
random_samples
[
sample_idx
,
:
sampling_params
.
best_of
].
tolist
()
sample_idx
,
:
sampling_params
.
n
].
tolist
()
else
:
# Generation phase.
parent_ids
=
list
(
range
(
num_parent_seqs
))
...
...
@@ -570,7 +570,7 @@ def _beam_search_sample(
is_prompt
=
seq_group
.
is_prompt
seq_ids
,
sampling_params
=
seq_group
.
seq_ids
,
seq_group
.
sampling_params
num_parent_seqs
=
len
(
seq_ids
)
beam_width
=
sampling_params
.
best_of
beam_width
=
sampling_params
.
n
seq_group_logprobs
=
logprobs
[
sample_idx
:
sample_idx
+
num_parent_seqs
]
if
is_prompt
:
# Prompt phase.
...
...
@@ -797,12 +797,11 @@ def _sample_with_torch(
greedy_samples
)
elif
sampling_type
in
(
SamplingType
.
RANDOM
,
SamplingType
.
RANDOM_SEED
):
max_
best_of
_in_batch
=
1
max_
n
_in_batch
=
1
for
seq_group
in
seq_groups
:
if
seq_group
.
is_prompt
:
sampling_params
=
seq_group
.
sampling_params
max_best_of_in_batch
=
max
(
max_best_of_in_batch
,
sampling_params
.
best_of
)
max_n_in_batch
=
max
(
max_n_in_batch
,
sampling_params
.
n
)
seq_groups_arg
=
(
None
if
sampling_type
==
SamplingType
.
RANDOM
else
seq_groups
)
...
...
@@ -812,13 +811,13 @@ def _sample_with_torch(
probs
[
long_sample_indices
],
sampling_tensors
.
top_ks
[
long_sample_indices
],
sampling_tensors
.
top_ps
[
long_sample_indices
],
max_
best_of
_in_batch
,
max_
n
_in_batch
,
seq_groups_arg
,
)
else
:
multinomial_samples
[
sampling_type
]
=
_multinomial
(
probs
[
long_sample_indices
],
max_
best_of
_in_batch
,
max_
n
_in_batch
,
seq_groups
=
seq_groups_arg
)
if
sampled_token_ids_tensor
is
not
None
:
...
...
vllm/outputs.py
View file @
cbc2ef55
...
...
@@ -141,7 +141,7 @@ class RequestOutput:
top_n_seqs
=
seqs
else
:
# Get the top-n sequences.
n
=
sampling_params
.
n
n
=
sampling_params
.
_real_n
or
sampling_params
.
n
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
top_n_seqs
=
sorted_seqs
[:
n
]
...
...
vllm/sampling_params.py
View file @
cbc2ef55
...
...
@@ -106,9 +106,8 @@ class SamplingParams(
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. This is treated as
the beam width when `use_beam_search` is True. By default, `best_of`
is set to `n`.
`best_of` must be greater than or equal to `n`. By default,
`best_of` is set to `n`.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
...
...
@@ -173,6 +172,7 @@ class SamplingParams(
n
:
int
=
1
best_of
:
Optional
[
int
]
=
None
_real_n
:
Optional
[
int
]
=
None
presence_penalty
:
float
=
0.0
frequency_penalty
:
float
=
0.0
repetition_penalty
:
float
=
1.0
...
...
@@ -282,7 +282,19 @@ class SamplingParams(
)
def
__post_init__
(
self
)
->
None
:
self
.
best_of
=
self
.
best_of
or
self
.
n
# how we deal with `best_of`:
# if `best_of` is not set, we default to `n`;
# if `best_of` is set, we set `n` to `best_of`,
# and set `_real_n` to the original `n`.
# when we return the result, we will check
# if we need to return `n` or `_real_n` results
if
self
.
best_of
:
if
self
.
best_of
<
self
.
n
:
raise
ValueError
(
f
"best_of must be greater than or equal to n, "
f
"got n=
{
self
.
n
}
and best_of=
{
self
.
best_of
}
."
)
self
.
_real_n
=
self
.
n
self
.
n
=
self
.
best_of
if
0
<
self
.
temperature
<
_MAX_TEMP
:
logger
.
warning
(
"temperature %s is less than %s, which may cause numerical "
...
...
@@ -329,12 +341,6 @@ class SamplingParams(
f
"type
{
type
(
self
.
n
)
}
"
)
if
self
.
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
self
.
n
}
."
)
if
not
isinstance
(
self
.
best_of
,
int
):
raise
ValueError
(
f
"best_of must be an int, but is of "
f
"type
{
type
(
self
.
best_of
)
}
"
)
if
self
.
best_of
<
self
.
n
:
raise
ValueError
(
f
"best_of must be greater than or equal to n, "
f
"got n=
{
self
.
n
}
and best_of=
{
self
.
best_of
}
."
)
if
not
-
2.0
<=
self
.
presence_penalty
<=
2.0
:
raise
ValueError
(
"presence_penalty must be in [-2, 2], got "
f
"
{
self
.
presence_penalty
}
."
)
...
...
@@ -385,7 +391,7 @@ class SamplingParams(
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
if
self
.
best_of
!=
self
.
n
and
self
.
output_kind
==
(
if
self
.
best_of
!=
self
.
_real_
n
and
self
.
output_kind
==
(
RequestOutputKind
.
DELTA
):
raise
ValueError
(
"best_of must equal n to use output_kind=DELTA"
)
...
...
@@ -393,10 +399,6 @@ class SamplingParams(
if
self
.
n
>
1
:
raise
ValueError
(
"n must be 1 when using greedy sampling, "
f
"got
{
self
.
n
}
."
)
assert
isinstance
(
self
.
best_of
,
int
)
if
self
.
best_of
>
1
:
raise
ValueError
(
"best_of must be 1 when using greedy sampling, "
f
"got
{
self
.
best_of
}
."
)
def
update_from_generation_config
(
self
,
...
...
@@ -453,7 +455,6 @@ class SamplingParams(
def
__repr__
(
self
)
->
str
:
return
(
f
"SamplingParams(n=
{
self
.
n
}
, "
f
"best_of=
{
self
.
best_of
}
, "
f
"presence_penalty=
{
self
.
presence_penalty
}
, "
f
"frequency_penalty=
{
self
.
frequency_penalty
}
, "
f
"repetition_penalty=
{
self
.
repetition_penalty
}
, "
...
...
vllm/sequence.py
View file @
cbc2ef55
...
...
@@ -803,14 +803,14 @@ class SequenceGroup:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if
self
.
sampling_params
:
best_of
=
self
.
sampling_params
.
best_of
assert
isinstance
(
best_of
,
int
)
if
best_of
>
self
.
num_seqs
():
n
=
self
.
sampling_params
.
n
assert
isinstance
(
n
,
int
)
if
n
>
self
.
num_seqs
():
# At prompt stage, the sequence group is not yet filled up
# and only has one sequence running. However, in the
# generation stage, we will have `
best_of
` sequences
# generation stage, we will have `
n
` sequences
# running.
return
best_of
return
n
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return
self
.
num_unfinished_seqs
()
...
...
vllm/tracing.py
View file @
cbc2ef55
...
...
@@ -96,7 +96,6 @@ class SpanAttributes(BaseSpanAttributes):
# The following span attribute names are added here because they are missing
# from the Semantic Conventions for LLM.
LLM_REQUEST_ID
=
"gen_ai.request.id"
LLM_REQUEST_BEST_OF
=
"gen_ai.request.best_of"
LLM_REQUEST_N
=
"gen_ai.request.n"
LLM_USAGE_NUM_SEQUENCES
=
"gen_ai.usage.num_sequences"
LLM_LATENCY_TIME_IN_QUEUE
=
"gen_ai.latency.time_in_queue"
...
...
vllm/worker/tpu_model_runner.py
View file @
cbc2ef55
...
...
@@ -49,7 +49,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
t
:
torch
.
Tensor
p
:
torch
.
Tensor
num_samples
:
int
best_of
:
List
[
int
]
n
:
List
[
int
]
seq_groups
:
List
[
List
[
int
]]
is_first_multi_step
:
bool
=
True
is_last_step
:
bool
=
True
...
...
@@ -65,7 +65,7 @@ class ModelInputForTPU(ModelRunnerInputBase):
"t"
:
self
.
t
,
"p"
:
self
.
p
,
"num_samples"
:
self
.
num_samples
,
"
best_of
"
:
self
.
best_of
,
"
n
"
:
self
.
n
,
"seq_groups"
:
self
.
seq_groups
,
"is_first_multi_step"
:
self
.
is_first_multi_step
,
"is_last_step"
:
self
.
is_last_step
,
...
...
@@ -435,7 +435,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
assert
len
(
seq_group_metadata_list
)
>
0
t
=
[]
p
=
[]
best_of
=
[]
n
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
sampling_params
=
seq_group_metadata
.
sampling_params
t
.
append
(
sampling_params
.
temperature
)
...
...
@@ -448,11 +448,11 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
raise
NotImplementedError
(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues."
)
if
sampling_params
.
best_of
>
_MAX_NUM_SAMPLES
:
if
sampling_params
.
n
>
_MAX_NUM_SAMPLES
:
raise
NotImplementedError
(
f
"Best of >
{
_MAX_NUM_SAMPLES
}
is not supported by the TPU "
"backend."
)
best_of
.
append
(
sampling_params
.
best_of
)
n
.
append
(
sampling_params
.
n
)
if
sampling_params
.
logprobs
is
not
None
:
raise
NotImplementedError
(
"logprobs is not currently supported by the TPU backend."
)
...
...
@@ -465,7 +465,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
num_seqs
=
len
(
seq_group_metadata
.
seq_data
)
t
+=
[
t
[
-
1
]]
*
(
num_seqs
-
1
)
p
+=
[
p
[
-
1
]]
*
(
num_seqs
-
1
)
best_of
+=
[
best_of
[
-
1
]]
*
(
num_seqs
-
1
)
n
+=
[
n
[
-
1
]]
*
(
num_seqs
-
1
)
num_paddings
=
padded_batch_size
-
len
(
t
)
t
+=
[
1.0
]
*
num_paddings
...
...
@@ -473,7 +473,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
t
=
torch
.
tensor
(
t
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
p
=
torch
.
tensor
(
p
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
return
t
,
p
,
best_of
return
t
,
p
,
n
def
prepare_model_input
(
self
,
...
...
@@ -493,8 +493,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
=
inputs
padded_batch_size
=
input_tokens
.
shape
[
0
]
t
,
p
,
best_of
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
t
,
p
,
n
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
padded_batch_size
)
num_samples
=
_MAX_NUM_SAMPLES
if
is_prompt
else
1
seq_groups
=
[
...
...
@@ -502,8 +502,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
for
metadata
in
seq_group_metadata_list
]
return
ModelInputForTPU
(
input_tokens
,
input_positions
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
best_of
,
seq_groups
)
input_lens
,
t
,
p
,
num_samples
,
n
,
seq_groups
)
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
ModelInputForTPU
:
...
...
@@ -609,7 +608,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_outputs
=
[]
for
j
in
range
(
model_input
.
best_of
[
i
]):
for
j
in
range
(
model_input
.
n
[
i
]):
next_token_id
=
next_token_ids
[
i
][
j
]
seq_outputs
.
append
(
SequenceOutput
(
seq_id
,
next_token_id
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment