Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a8c1d161
Unverified
Commit
a8c1d161
authored
Sep 18, 2024
by
afeldman-nm
Committed by
GitHub
Sep 18, 2024
Browse files
[Core] *Prompt* logprobs support in Multi-step (#8199)
parent
7c7714d8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
300 additions
and
59 deletions
+300
-59
tests/conftest.py
tests/conftest.py
+52
-32
tests/models/utils.py
tests/models/utils.py
+102
-6
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+92
-0
tests/utils.py
tests/utils.py
+2
-1
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+52
-20
No files found.
tests/conftest.py
View file @
a8c1d161
...
...
@@ -20,6 +20,8 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
BatchFeature
)
from
transformers.models.auto.auto_factory
import
_BaseAutoModelClass
from
tests.models.utils
import
(
TokensTextLogprobs
,
TokensTextLogprobsPromptLogprobs
)
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.video
import
VideoAsset
...
...
@@ -33,7 +35,6 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
to_enc_dec_tuple_list
,
zip_enc_dec_prompts
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
SampleLogprobs
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
cuda_device_count_stateless
,
identity
,
is_cpu
)
...
...
@@ -469,7 +470,7 @@ class HfRunner:
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
T
uple
[
List
[
int
],
str
,
List
[
Dict
[
int
,
float
]]]
]:
)
->
List
[
T
okensTextLogprobs
]:
all_logprobs
:
List
[
List
[
Dict
[
int
,
float
]]]
=
[]
all_output_ids
:
List
[
List
[
int
]]
=
[]
all_output_strs
:
List
[
str
]
=
[]
...
...
@@ -525,7 +526,7 @@ class HfRunner:
max_tokens
:
int
,
num_logprobs
:
int
,
**
kwargs
:
Any
,
)
->
List
[
T
uple
[
List
[
int
],
str
,
List
[
Dict
[
int
,
float
]]]
]:
)
->
List
[
T
okensTextLogprobs
]:
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
...
...
@@ -653,14 +654,16 @@ class VllmRunner:
@
staticmethod
def
_final_steps_generate_w_logprobs
(
req_outputs
:
List
[
RequestOutput
],
)
->
List
[
T
uple
[
List
[
int
],
str
,
Optional
[
Sample
Logprobs
]
]]
:
outputs
:
List
[
T
uple
[
List
[
int
],
str
,
Optional
[
Sample
Logprobs
]
]]
=
[]
)
->
List
[
T
okensTextLogprobsPrompt
Logprobs
]:
outputs
:
List
[
T
okensTextLogprobsPrompt
Logprobs
]
=
[]
for
req_output
in
req_outputs
:
assert
len
(
req_output
.
outputs
)
>
0
for
sample
in
req_output
.
outputs
:
output_str
=
sample
.
text
output_ids
=
list
(
sample
.
token_ids
)
output_logprobs
=
sample
.
logprobs
outputs
.
append
((
output_ids
,
output_str
,
output_logprobs
))
outputs
.
append
((
output_ids
,
output_str
,
output_logprobs
,
req_output
.
prompt_logprobs
))
return
outputs
def
generate_w_logprobs
(
...
...
@@ -670,7 +673,8 @@ class VllmRunner:
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
Union
[
List
[
TokensTextLogprobs
],
List
[
TokensTextLogprobsPromptLogprobs
]]:
assert
sampling_params
.
logprobs
is
not
None
if
images
is
not
None
:
...
...
@@ -695,13 +699,20 @@ class VllmRunner:
req_outputs
=
self
.
model
.
generate
(
inputs
,
sampling_params
=
sampling_params
)
return
self
.
_final_steps_generate_w_logprobs
(
req_outputs
)
toks_str_logsprobs_prompt_logprobs
=
(
self
.
_final_steps_generate_w_logprobs
(
req_outputs
))
# Omit prompt logprobs if not required by sampling params
return
([
x
[
0
:
-
1
]
for
x
in
toks_str_logsprobs_prompt_logprobs
]
if
sampling_params
.
prompt_logprobs
is
None
else
toks_str_logsprobs_prompt_logprobs
)
def
generate_encoder_decoder_w_logprobs
(
self
,
encoder_decoder_prompts
:
List
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
sampling_params
:
SamplingParams
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
)
->
Union
[
List
[
TokensTextLogprobs
],
List
[
TokensTextLogprobsPromptLogprobs
]]:
'''
Logprobs generation for vLLM encoder/decoder models
'''
...
...
@@ -709,7 +720,12 @@ class VllmRunner:
assert
sampling_params
.
logprobs
is
not
None
req_outputs
=
self
.
model
.
generate
(
encoder_decoder_prompts
,
sampling_params
=
sampling_params
)
return
self
.
_final_steps_generate_w_logprobs
(
req_outputs
)
toks_str_logsprobs_prompt_logprobs
=
(
self
.
_final_steps_generate_w_logprobs
(
req_outputs
))
# Omit prompt logprobs if not required by sampling params
return
([
x
[
0
:
-
1
]
for
x
in
toks_str_logsprobs_prompt_logprobs
]
if
sampling_params
.
prompt_logprobs
is
None
else
toks_str_logsprobs_prompt_logprobs
)
def
generate_greedy
(
self
,
...
...
@@ -727,44 +743,48 @@ class VllmRunner:
prompts
:
List
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
num_prompt_logprobs
:
Optional
[
int
]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
logprobs
=
num_logprobs
,
stop_token_ids
=
stop_token_ids
)
outputs
=
self
.
generate_w_logprobs
(
prompts
,
greedy_logprobs_params
,
images
=
images
,
audios
=
audios
,
videos
=
videos
)
return
[(
output_ids
,
output_str
,
output_logprobs
)
for
output_ids
,
output_str
,
output_logprobs
in
outputs
]
)
->
Union
[
List
[
TokensTextLogprobs
],
List
[
TokensTextLogprobsPromptLogprobs
]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
max_tokens
=
max_tokens
,
logprobs
=
num_logprobs
,
prompt_logprobs
=
(
num_prompt_logprobs
),
stop_token_ids
=
stop_token_ids
)
return
self
.
generate_w_logprobs
(
prompts
,
greedy_logprobs_params
,
images
=
images
,
audios
=
audios
,
videos
=
videos
)
def generate_encoder_decoder_greedy_logprobs(
    self,
    encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
    max_tokens: int,
    num_logprobs: int,
    num_prompt_logprobs: Optional[int] = None,
) -> Union[List[TokensTextLogprobs],
           List[TokensTextLogprobsPromptLogprobs]]:
    '''
    Greedy logprobs generation for vLLM encoder/decoder models

    Builds greedy (temperature 0.0, no beam search) sampling params and
    delegates to `generate_encoder_decoder_w_logprobs`.

    Args:
        encoder_decoder_prompts: explicit encoder/decoder prompts to run
        max_tokens: maximum number of tokens to generate per prompt
        num_logprobs: forwarded as `SamplingParams.logprobs`
        num_prompt_logprobs: forwarded as `SamplingParams.prompt_logprobs`;
                             `None` (the default) requests no prompt
                             logprobs

    Returns:
        Whatever `generate_encoder_decoder_w_logprobs` returns for these
        sampling params.
    '''
    # NOTE: the docstring must be the first statement of the function;
    # placed after the assignment below it would be a no-op string
    # expression and `__doc__` would stay `None`.
    greedy_logprobs_params = SamplingParams(
        temperature=0.0,
        use_beam_search=False,
        max_tokens=max_tokens,
        logprobs=num_logprobs,
        prompt_logprobs=(num_prompt_logprobs),
    )

    return self.generate_encoder_decoder_w_logprobs(
        encoder_decoder_prompts, greedy_logprobs_params)
def
generate_beam_search
(
self
,
prompts
:
List
[
str
],
...
...
tests/models/utils.py
View file @
a8c1d161
import
warnings
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
vllm.sequence
import
Logprob
,
SampleLogprobs
from
vllm.sequence
import
Logprob
,
PromptLogprobs
,
SampleLogprobs
TokensText
=
Tuple
[
List
[
int
],
str
]
...
...
@@ -34,20 +34,47 @@ def check_outputs_equal(
assert
output_ids_0
==
output_ids_1
,
fail_msg
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs
=
Tuple
[
List
[
int
],
str
,
Optional
[
Union
[
List
[
Dict
[
int
,
float
]],
SampleLogprobs
]]]
# Allow for tokens to be represented as str's rather than IDs
# Allow for tokens to be represented as str's rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs
=
Tuple
[
List
[
str
],
str
,
Optional
[
Union
[
List
[
Dict
[
str
,
float
]],
List
[
Dict
[
str
,
Logprob
]]]]]
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs
=
Tuple
[
List
[
int
],
str
,
Optional
[
Union
[
List
[
Dict
[
int
,
float
]],
SampleLogprobs
]],
Optional
[
Union
[
List
[
Optional
[
Dict
[
int
,
float
]]],
PromptLogprobs
]]]
def
check_logprobs_close
(
*
,
outputs_0_lst
:
Sequence
[
Union
[
TokensTextLogprobs
,
TextTextLogprobs
]],
outputs_1_lst
:
Sequence
[
Union
[
TokensTextLogprobs
,
TextTextLogprobs
]],
outputs_0_lst
:
Sequence
[
Union
[
TokensTextLogprobs
,
TokensTextLogprobsPromptLogprobs
,
TextTextLogprobs
]],
outputs_1_lst
:
Sequence
[
Union
[
TokensTextLogprobs
,
TokensTextLogprobsPromptLogprobs
,
TextTextLogprobs
]],
name_0
:
str
,
name_1
:
str
,
num_outputs_0_skip_tokens
:
int
=
0
,
...
...
@@ -57,6 +84,18 @@ def check_logprobs_close(
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
How sample logprobs are compared:
* `always_check_logprobs == True`: set of highest-logprob token ids
must match between seq0 and seq1 at all sampled token offsets
* `always_check_logprobs == False`: highest-logprob token ids are
only compared at sampled token offsets for which generated token
ids don't match
Prompt logprobs must be provided either for both input sequences, or
for neither. If prompt logprobs are provided, then highest-logprob
prompt token ids must match between seq0 and seq1 at all prompt token
offsets.
Args:
outputs_0_lst: First sequence to compare
outputs_1_lst: Second sequence to compare
...
...
@@ -78,8 +117,65 @@ def check_logprobs_close(
for
prompt_idx
,
(
outputs_0
,
outputs_1
)
in
enumerate
(
zip
(
outputs_0_lst
,
outputs_1_lst
)):
output_ids_0
,
output_str_0
,
logprobs_0
=
outputs_0
output_ids_1
,
output_str_1
,
logprobs_1
=
outputs_1
assert
len
(
outputs_0
)
==
len
(
outputs_1
)
if
len
(
outputs_0
)
==
3
:
assert
len
(
outputs_1
)
==
3
# Break out tokens, text & sample logprobs
# (prompt logprobs were not provided)
output_ids_0
,
output_str_0
,
logprobs_0
=
outputs_0
output_ids_1
,
output_str_1
,
logprobs_1
=
outputs_1
elif
len
(
outputs_0
)
==
4
:
assert
len
(
outputs_1
)
==
4
# Break out tokens, text, sample logprobs & prompt logprobs
(
output_ids_0
,
output_str_0
,
logprobs_0
,
prompt_logprobs_0
,
)
=
outputs_0
(
output_ids_1
,
output_str_1
,
logprobs_1
,
prompt_logprobs_1
,
)
=
outputs_1
# Test prompt logprobs closeness
if
(
prompt_logprobs_0
is
not
None
and
prompt_logprobs_1
is
not
None
):
# Both sequences' prompt logprobs lists are not `None`
# (although individual list elements may be `None`);
# for each token's logprobs:
for
idx
,
(
logprobs_elem_0
,
logprobs_elem_1
)
in
enumerate
(
zip
(
prompt_logprobs_0
,
prompt_logprobs_1
)):
fail_msg
=
(
f
"Prompt logprobs test:"
f
"
\n
{
name_0
}
:
\t
Prompt index
{
idx
}
\t
{
logprobs_elem_0
}
"
f
"
\n
{
name_1
}
:
\t
Prompt index
{
idx
}
\t
{
logprobs_elem_1
}
"
)
if
logprobs_elem_0
is
None
:
# If the seq 0 token's logprobs are `None`,
# the seq 1 token's logprobs must be `None`
assert
logprobs_elem_1
is
None
,
fail_msg
else
:
# If the seq 0 token's logprobs are not `None`,
# the seq 1 token's logprobs must not be `None`
assert
logprobs_elem_1
is
not
None
,
fail_msg
# Logprobs check: top-k token choices must be the same
assert
(
set
(
logprobs_elem_0
.
keys
())
==
set
(
logprobs_elem_1
.
keys
())),
fail_msg
else
:
# Both sequence logprobs lists must be `None`
fail_msg
=
(
f
"Prompt logprobs test:"
f
"
\n
{
name_0
}
:
\t
logprobs
\t
{
prompt_logprobs_0
}
"
f
"
\n
{
name_1
}
:
\t
logprobs
\t
{
prompt_logprobs_1
}
"
)
assert
(
prompt_logprobs_0
is
None
and
prompt_logprobs_1
is
None
),
fail_msg
else
:
raise
ValueError
(
f
"Outputs tuple must have 3 or 4 elements but "
f
"
{
len
(
outputs_0
)
}
elements were provided: "
f
"
{
outputs_0
}
"
)
if
logprobs_0
is
None
:
logprobs_0
=
[
None
]
*
len
(
output_ids_0
)
...
...
tests/multi_step/test_correctness_llm.py
View file @
a8c1d161
...
...
@@ -100,3 +100,95 @@ def test_multi_step_llm(
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
def test_multi_step_llm_w_prompt_logprobs(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    tp_size: int,
    max_tokens: int,
    enforce_eager: bool,
    num_scheduler_steps: int,
    num_prompts: int,
    num_logprobs: Optional[int],
    num_prompt_logprobs: Optional[int],
) -> None:
    """Test prompt logprobs with multi-step scheduling via sync LLM Engine.

    Set up a vLLM engine instance w/ multi-step scheduling, and a second
    vLLM engine instance w/ single-step scheduling as a ground-truth
    reference.

    Prompt them with the same example prompts.

    Validate:
    * All generated logprobs are very close

    Args:
      vllm_runner: vLLM model runner fixture
      example_prompts: test fixture providing example prompts
      model: model under test (same for single- and multi-step engines)
      dtype: tensor datatype for engine to utilize
      tp_size: degree of tensor-parallelism
      max_tokens: the maximum number of tokens to generate
      enforce_eager: forwarded unchanged to both engines
      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
                           GPU -> CPU output transfer
      num_prompts: number of example prompts under test
      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
                    completions endpoint; `None` -> no logprobs
      num_prompt_logprobs: number of logprobs to return for each prompt token;
                           note that this argument is not supported by the
                           OpenAI completions endpoint.
    """

    prompts = example_prompts
    if len(prompts) < num_prompts:
        # Tile the example prompts until at least `num_prompts` exist.
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    # Engine under test: multi-step scheduling.
    with vllm_runner(
            model,
            dtype=dtype,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=0.7,
            tensor_parallel_size=tp_size,
            use_v2_block_manager=True,
            num_scheduler_steps=num_scheduler_steps,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            prompts,
            max_tokens,
            num_logprobs,
            num_prompt_logprobs=num_prompt_logprobs)

    # Reference engine: `num_scheduler_steps` is not passed, so this is the
    # single-step ground truth described in the docstring.
    with vllm_runner(
            model,
            dtype=dtype,
            enforce_eager=enforce_eager,
            gpu_memory_utilization=0.7,
            tensor_parallel_size=tp_size,
    ) as vllm_model:
        single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
            prompts,
            max_tokens,
            num_logprobs,
            num_prompt_logprobs=num_prompt_logprobs)

    # Label each side by what it actually is, so that assertion failure
    # messages point at the right engine (neither side is an HF model).
    check_logprobs_close(
        outputs_0_lst=single_step_vllm_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="vllm_single_step",
        name_1="vllm_multi_step",
    )
tests/utils.py
View file @
a8c1d161
...
...
@@ -493,6 +493,7 @@ async def completions_with_server_args(
'''
outputs
=
None
max_wait_seconds
=
240
*
3
# 240 is default
with
RemoteOpenAIServer
(
model_name
,
server_cli_args
,
max_wait_seconds
=
max_wait_seconds
)
as
server
:
...
...
@@ -503,7 +504,7 @@ async def completions_with_server_args(
stream
=
False
,
max_tokens
=
5
,
logprobs
=
num_logprobs
)
assert
outputs
is
not
None
assert
outputs
is
not
None
,
"Completion API call failed."
return
outputs
...
...
vllm/worker/multi_step_model_runner.py
View file @
a8c1d161
...
...
@@ -614,34 +614,66 @@ def _pythonize_sampler_output(
frozen_model_input
=
model_input
.
frozen_model_input
assert
frozen_model_input
.
sampling_metadata
is
not
None
sampling_metadata
=
frozen_model_input
.
sampling_metadata
# samples generation should have been skipped
assert
not
output
.
outputs
pinned_buffer
=
pinned_sampled_token_buffer
[:
model_input
.
num_queries
]
# CPU GPU sync
pinned_buffer
=
pinned_buffer
.
copy_
(
sampled_token_ids
,
non_blocking
=
False
)
# We guarantee output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However we should check whether logprobs pythonization may
# be skipped entirely, i.e. because no logprobs were requested
# or pythonization was not deferred. To that end,
#
# * `prompt_logprobs_are_requested_for_prefill` signals that
# there are *any* prefill-phase requests which specify that
# prompt logprobs should be returned.
#
# * `any_logprobs_are_requested` signals that there are any
# requests which (1) specify that sample logprobs should be
# returned, or (2) are in the prefill phase AND specify that
# prompt logprobs should be returned.
#
# Later on, these flags cause adjustments to the pythonization
# process to accommodate logprobs.
seq_groups
=
sampling_metadata
.
seq_groups
prompt_logprobs_are_requested_for_prefill
=
any
([
sg
.
sampling_params
.
prompt_logprobs
is
not
None
and
sg
.
is_prompt
for
sg
in
seq_groups
])
any_logprobs_are_requested
=
(
prompt_logprobs_are_requested_for_prefill
or
any
([
sg
.
sampling_params
.
logprobs
is
not
None
for
sg
in
seq_groups
]))
if
prompt_logprobs_are_requested_for_prefill
:
# CPU GPU sync, after gathering *only* sampled tokens (since
# requesting prompt logprobs leads `sampled_token_ids` to
# include prompt token ids in addition to sampled token ids.)
sample_idx_tensor
=
torch
.
tensor
(
[
sdx
for
sg
in
seq_groups
for
sdx
in
sg
.
sample_indices
])
pinned_buffer
=
pinned_buffer
.
copy_
(
sampled_token_ids
[
sample_idx_tensor
,
:],
non_blocking
=
False
)
else
:
# CPU GPU sync
pinned_buffer
=
pinned_buffer
.
copy_
(
sampled_token_ids
,
non_blocking
=
False
)
# this will not block as the tensors are already on CPU
samples_list
=
pinned_buffer
.
tolist
()
sampling_metadata
=
frozen_model_input
.
sampling_metadata
skip_sampler_cpu_output
=
(
frozen_model_input
.
sampling_metadata
.
skip_sampler_cpu_output
)
# We are guaranteed output tensors are ready, so it is safe to
# pythonize the sampler output & obtain CPU-side logprobs.
#
# However this computation may be skipped entirely
# if no pythonization was deferred.
seq_groups
=
sampling_metadata
.
seq_groups
logprobs_are_requested
=
any
([
sg
.
sampling_params
.
logprobs
is
not
None
or
sg
.
sampling_params
.
prompt_logprobs
is
not
None
for
sg
in
seq_groups
])
# *Don't* skip logprobs pythonization *if*:
# * Any requests require logprobs to be returned in this
# iteration AND
# * These requests are being scheduled in a fashion which
# defers pythonization (i.e. multi-step scheduling.)
do_pythonize_logprobs
=
(
skip_sampler_cpu_output
and
logprobs_are_requested
)
and
any_
logprobs_are_requested
)
(
prompt_logprobs
,
sample_logprobs
,
...
...
@@ -666,7 +698,7 @@ def _pythonize_sampler_output(
prompt_logprobs
[
sgdx
],
sample_logprobs
[
sgdx
],
)
elif
logprobs_are_requested
:
elif
any_
logprobs_are_requested
:
(
group_prompt_logprobs
,
group_sample_logprobs
,
...
...
@@ -696,7 +728,7 @@ def _pythonize_sampler_output(
seq_output
.
parent_seq_id
=
seq_ids
[
parent_id
]
seq_output
.
output_token
=
next_token_id
if
logprobs_are_requested
:
if
any_
logprobs_are_requested
:
seq_output
.
logprobs
=
group_sample_logprobs
[
tdx
]
else
:
logprobs
=
next
(
iter
(
seq_output
.
logprobs
.
values
()))
...
...
@@ -714,7 +746,7 @@ def _pythonize_sampler_output(
seq_outputs
.
append
(
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
(
group_sample_logprobs
[
tdx
]
if
logprobs_are_requested
else
{
if
any_
logprobs_are_requested
else
{
next_token_id
:
Logprob
(
logprob
=
float
(
'inf'
),
rank
=
None
,
...
...
@@ -722,12 +754,12 @@ def _pythonize_sampler_output(
})))
if
cache
is
not
None
:
completion_seq_group_output
.
prompt_logprobs
=
\
group_prompt_logprobs
if
logprobs_are_requested
else
None
group_prompt_logprobs
if
any_
logprobs_are_requested
else
None
output
.
outputs
.
append
(
completion_seq_group_output
)
else
:
output
.
outputs
.
append
(
CompletionSequenceGroupOutput
(
seq_outputs
,
(
group_prompt_logprobs
if
logprobs_are_requested
else
None
)))
if
any_
logprobs_are_requested
else
None
)))
assert
len
(
output
.
outputs
)
>
0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment