Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
428dd144
"xcode/vscode:/vscode.git/clone" did not exist on "27615dbc5fa74a9abfda13d301963c6c797ea21b"
Unverified
Commit
428dd144
authored
Aug 29, 2024
by
afeldman-nm
Committed by
GitHub
Aug 29, 2024
Browse files
[Core] Logprobs support in Multi-step (#7652)
parent
4abed65c
Changes
103
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
316 additions
and
107 deletions
+316
-107
tests/models/utils.py
tests/models/utils.py
+26
-17
tests/multi_step/test_correctness_async_llm.py
tests/multi_step/test_correctness_async_llm.py
+69
-30
tests/multi_step/test_correctness_llm.py
tests/multi_step/test_correctness_llm.py
+74
-21
tests/spec_decode/test_multi_step_worker.py
tests/spec_decode/test_multi_step_worker.py
+2
-1
tests/spec_decode/test_spec_decode_worker.py
tests/spec_decode/test_spec_decode_worker.py
+2
-1
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+2
-2
tests/test_sequence.py
tests/test_sequence.py
+3
-2
tests/utils.py
tests/utils.py
+60
-0
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+2
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-2
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+12
-3
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+46
-19
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+2
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+1
-1
vllm/executor/cpu_executor.py
vllm/executor/cpu_executor.py
+2
-1
vllm/executor/distributed_gpu_executor.py
vllm/executor/distributed_gpu_executor.py
+2
-1
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+2
-1
vllm/executor/gpu_executor.py
vllm/executor/gpu_executor.py
+2
-1
vllm/executor/multiproc_gpu_executor.py
vllm/executor/multiproc_gpu_executor.py
+2
-1
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+2
-1
No files found.
tests/models/utils.py
View file @
428dd144
import
warnings
from
typing
import
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
Logprob
,
SampleLogprobs
TokensText
=
Tuple
[
List
[
int
],
str
]
...
...
@@ -38,34 +38,39 @@ TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float
]],
SampleLogprobs
]]]
# Allow for tokens to be represented as str's rather than IDs
TextTextLogprobs
=
Tuple
[
List
[
str
],
str
,
Optional
[
Union
[
List
[
Dict
[
str
,
float
]],
List
[
Dict
[
str
,
Logprob
]]]]]
def
check_logprobs_close
(
*
,
outputs_0_lst
:
Sequence
[
TokensTextLogprobs
],
outputs_1_lst
:
Sequence
[
TokensTextLogprobs
],
outputs_0_lst
:
Sequence
[
Union
[
TokensTextLogprobs
,
TextTextLogprobs
]
],
outputs_1_lst
:
Sequence
[
Union
[
TokensTextLogprobs
,
TextTextLogprobs
]
],
name_0
:
str
,
name_1
:
str
,
num_outputs_0_skip_tokens
:
int
=
0
,
warn_on_mismatch
:
bool
=
True
,
):
"""
Compare the logprobs of two sequences generated by different models,
always_check_logprobs
:
bool
=
False
,
)
->
None
:
"""
Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
Arguments:
* outputs_0_lst: First sequence to compare
* outputs_0_lst: Second sequence to compare
* name_0: sequence #0 name
* name_1: sequence #1 name
* num_outputs_0_skip_tokens: If > 0, specifies the number of initial
Args:
outputs_0_lst: First sequence to compare
outputs_0_lst: Second sequence to compare
name_0: sequence #0 name
name_1: sequence #1 name
num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard
before comparison, i.e. all
of sequence #1 will be compared to
sequence #0 beginning at index
num_outputs_0_skip_tokens
*
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences
always_check_logprobs: If true, check logprobs even when tokens match
"""
assert
len
(
outputs_0_lst
)
==
len
(
outputs_1_lst
)
...
...
@@ -94,8 +99,12 @@ def check_logprobs_close(
for
idx
,
(
output_id_0
,
output_id_1
)
in
enumerate
(
zip
(
output_ids_0
,
output_ids_1
)):
# If generated tokens don't match, then
if
output_id_0
!=
output_id_1
:
is_tok_mismatch
=
output_id_0
!=
output_id_1
# If generated tokens don't match
# or it is desired to always check logprobs,
# then
if
is_tok_mismatch
or
always_check_logprobs
:
logprobs_elem_0
=
logprobs_0
[
idx
]
logprobs_elem_1
=
logprobs_1
[
idx
]
...
...
@@ -111,7 +120,7 @@ def check_logprobs_close(
assert
output_id_0
in
logprobs_elem_1
,
fail_msg
assert
output_id_1
in
logprobs_elem_0
,
fail_msg
if
warn_on_mismatch
:
if
warn_on_mismatch
and
is_tok_mismatch
:
with
warnings
.
catch_warnings
():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
...
...
tests/multi_step/test_correctness_async_llm.py
View file @
428dd144
# Test the AsyncLLMEngine with multi-step-decoding
from
typing
import
List
from
typing
import
List
,
Optional
import
pytest
from
..utils
import
RemoteOpenAIServer
from
..models.utils
import
check_logprobs_close
from
..utils
import
(
completions_with_server_args
,
get_client_text_generations
,
get_client_text_logprob_generations
)
MODELS
=
[
"JackFram/llama-160m"
,
...
...
@@ -23,22 +25,6 @@ DEFAULT_SERVER_ARGS: List[str] = [
]
async
def
completions_with_server_args
(
prompts
:
List
[
str
],
model_name
:
str
,
server_cli_args
:
List
[
str
]):
outputs
=
None
with
RemoteOpenAIServer
(
model_name
,
server_cli_args
)
as
server
:
async
with
server
.
get_async_client
()
as
client
:
outputs
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompts
,
temperature
=
0
,
stream
=
False
,
max_tokens
=
5
)
assert
outputs
is
not
None
return
outputs
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
((
"tp_size, pp_size"
),
[
(
1
,
1
),
...
...
@@ -47,12 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@
pytest
.
mark
.
parametrize
(
"eager_mode"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
@
pytest
.
mark
.
parametrize
(
"is_async"
,
[
False
,
True
])
@
pytest
.
mark
.
asyncio
async
def
test_multi_step
(
example_prompts
,
model
:
str
,
tp_size
:
int
,
pp_size
:
int
,
eager_mode
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
is_async
:
bool
):
async
def
test_multi_step
(
example_prompts
,
model
:
str
,
tp_size
:
int
,
pp_size
:
int
,
eager_mode
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
is_async
:
bool
,
num_logprobs
:
Optional
[
int
],
)
->
None
:
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
client/server environment.
Set up an engine with single-step scheduling as a ground-truth reference.
Send a completions API request to both engines with the same prompts.
Validate:
* Generated tokens match
* Generated logprobs are all very close
Args:
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
tp_size: degree of tensor-parallelism
pp_size: degree of pipeline-parallelism
eager_mode
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -77,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
str
(
pp_size
),
]
# Spin up client/server & issue completion API requests.
# Default `max_wait_seconds` is 240 but was empirically
# was raised 3x to 720 *just for this test* due to
# observed timeouts in GHA CI
ref_completions
=
await
completions_with_server_args
(
prompts
,
model
,
server_args
+
distributed_args
)
prompts
,
model
,
server_args
+
distributed_args
,
num_logprobs
,
max_wait_seconds
=
3
*
240
)
test_completions
=
await
completions_with_server_args
(
prompts
,
model
,
ms_server_args
+
distributed_args
)
def
get_text_generations
(
completions
):
return
[
x
.
text
for
x
in
completions
.
choices
]
ref_generations
=
get_text_generations
(
ref_completions
)
test_generations
=
get_text_generations
(
test_completions
)
prompts
,
model
,
ms_server_args
+
distributed_args
,
num_logprobs
,
max_wait_seconds
=
3
*
240
)
# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
ref_generations
=
get_client_text_generations
(
ref_completions
)
test_generations
=
get_client_text_generations
(
test_completions
)
assert
ref_generations
==
test_generations
# Assert multi-step scheduling produces nearly-identical logprobs
# to single-step scheduling.
ref_text_logprobs
=
get_client_text_logprob_generations
(
ref_completions
)
test_text_logprobs
=
get_client_text_logprob_generations
(
test_completions
)
check_logprobs_close
(
outputs_0_lst
=
ref_text_logprobs
,
outputs_1_lst
=
test_text_logprobs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
tests/multi_step/test_correctness_llm.py
View file @
428dd144
# Test the LLMEngine with multi-step-decoding
from
typing
import
Optional
import
pytest
from
..models.utils
import
check_outputs_equal
from
..models.utils
import
check_logprobs_close
,
check_outputs_equal
MODELS
=
[
"JackFram/llama-160m"
,
...
...
@@ -18,10 +20,45 @@ NUM_PROMPTS = [10]
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"num_scheduler_steps"
,
NUM_SCHEDULER_STEPS
)
@
pytest
.
mark
.
parametrize
(
"num_prompts"
,
NUM_PROMPTS
)
def
test_multi_step_llm
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
tp_size
:
int
,
max_tokens
:
int
,
enforce_eager
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
None
,
5
])
def
test_multi_step_llm
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
dtype
:
str
,
tp_size
:
int
,
max_tokens
:
int
,
enforce_eager
:
int
,
num_scheduler_steps
:
int
,
num_prompts
:
int
,
num_logprobs
:
Optional
[
int
],
)
->
None
:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
Set up a HuggingFace (HF) transformers model as a ground-truth reference.
Prompt them with the same example prompts.
Validate:
* Generated tokens match
* Generated logprobs are all very close
Args:
hf_runner: HF transformers model runner fixture
vllm_runner: vLLM model runner fixture
example_prompts: test fixture providing example prompts
model: model under test (same for single- and multi-step engines)
dtype: tensor datatype for engine to utilize
tp_size: degree of tensor-parallelism
max_tokens: the maximum number of tokens to generate
enforce_eager
num_scheduler_steps: for multi-step scheduling, GPU-side steps per
GPU -> CPU output transfer
num_prompts: number of example prompts under test
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> no logprobs
"""
prompts
=
example_prompts
if
len
(
prompts
)
<
num_prompts
:
...
...
@@ -29,21 +66,37 @@ def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str,
prompts
=
prompts
[:
num_prompts
]
assert
len
(
prompts
)
==
num_prompts
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
,
gpu_memory_utilization
=
0.7
,
tensor_parallel_size
=
tp_size
,
use_v2_block_manager
=
True
,
num_scheduler_steps
=
num_scheduler_steps
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompts
,
max_tokens
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
,
gpu_memory_utilization
=
0.7
,
tensor_parallel_size
=
tp_size
,
use_v2_block_manager
=
True
,
num_scheduler_steps
=
num_scheduler_steps
,
)
as
vllm_model
:
vllm_outputs
=
(
vllm_model
.
generate_greedy
(
prompts
,
max_tokens
)
if
num_logprobs
is
None
else
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
))
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
prompts
,
max_tokens
)
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
hf_outputs
=
(
hf_model
.
generate_greedy
(
prompts
,
max_tokens
)
if
num_logprobs
is
None
else
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
num_logprobs
))
if
num_logprobs
is
None
:
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
else
:
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
tests/spec_decode/test_multi_step_worker.py
View file @
428dd144
...
...
@@ -5,9 +5,10 @@ from unittest.mock import MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
ExecuteModelRequest
,
HiddenStates
,
Logprob
,
SamplerOutput
,
get_all_seq_ids
)
get_all_seq_ids
)
from
vllm.spec_decode.draft_model_runner
import
TP1DraftModelRunner
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
...
...
tests/spec_decode/test_spec_decode_worker.py
View file @
428dd144
...
...
@@ -7,8 +7,9 @@ from unittest.mock import MagicMock
import
pytest
import
torch
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
,
SequenceOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceOutput
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.metrics
import
(
AsyncMetricsCollector
,
SpecDecodeWorkerMetrics
)
...
...
tests/spec_decode/utils.py
View file @
428dd144
...
...
@@ -8,12 +8,12 @@ from unittest.mock import MagicMock
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
Logprob
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.model_runner
import
ModelRunner
...
...
tests/test_sequence.py
View file @
428dd144
...
...
@@ -2,9 +2,10 @@ from array import array
import
pytest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
S
amplerOutput
,
SequenceData
,
SequenceOutput
)
CompletionSequenceGroupOutput
,
S
equenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
...
...
tests/utils.py
View file @
428dd144
...
...
@@ -11,9 +11,11 @@ from typing import Any, Callable, Dict, List, Optional
import
openai
import
requests
from
openai.types.completion
import
Completion
from
transformers
import
AutoTokenizer
from
typing_extensions
import
ParamSpec
from
tests.models.utils
import
TextTextLogprobs
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
...
...
@@ -432,3 +434,61 @@ def fork_new_process_for_each_test(
f
" args
{
args
}
and kwargs
{
kwargs
}
"
)
return
wrapper
async
def
completions_with_server_args
(
prompts
:
List
[
str
],
model_name
:
str
,
server_cli_args
:
List
[
str
],
num_logprobs
:
Optional
[
int
],
max_wait_seconds
:
int
=
240
,
)
->
Completion
:
'''Construct a remote OpenAI server, obtain an async client to the
server & invoke the completions API to obtain completions.
Args:
prompts: test prompts
model_name: model to spin up on the vLLM server
server_cli_args: CLI args for starting the server
num_logprobs: Number of logprobs to report (or `None`)
max_wait_seconds: timeout interval for bringing up server.
Default: 240sec
Returns:
OpenAI Completion instance
'''
outputs
=
None
with
RemoteOpenAIServer
(
model_name
,
server_cli_args
,
max_wait_seconds
=
max_wait_seconds
)
as
server
:
client
=
server
.
get_async_client
()
outputs
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompts
,
temperature
=
0
,
stream
=
False
,
max_tokens
=
5
,
logprobs
=
num_logprobs
)
assert
outputs
is
not
None
return
outputs
def
get_client_text_generations
(
completions
:
Completion
)
->
List
[
str
]:
'''Extract generated tokens from the output of a
request made to an Open-AI-protocol completions endpoint.
'''
return
[
x
.
text
for
x
in
completions
.
choices
]
def
get_client_text_logprob_generations
(
completions
:
Completion
)
->
List
[
TextTextLogprobs
]:
'''Operates on the output of a request made to an Open-AI-protocol
completions endpoint; obtains top-rank logprobs for each token in
each :class:`SequenceGroup`
'''
text_generations
=
get_client_text_generations
(
completions
)
text
=
''
.
join
(
text_generations
)
return
[(
text_generations
,
text
,
(
None
if
x
.
logprobs
is
None
else
x
.
logprobs
.
top_logprobs
))
for
x
in
completions
.
choices
]
vllm/engine/async_llm_engine.py
View file @
428dd144
...
...
@@ -22,11 +22,12 @@ from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
print_warning_once
...
...
vllm/engine/llm_engine.py
View file @
428dd144
...
...
@@ -33,6 +33,7 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
...
...
@@ -40,8 +41,8 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
...
...
vllm/engine/output_processor/multi_step.py
View file @
428dd144
...
...
@@ -4,6 +4,8 @@ from typing import Callable, List
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.single_step
import
(
single_step_process_prompt_logprob
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
...
...
@@ -46,9 +48,16 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
# TODO(sang): Prompt logprob currently not implemented in multi step
# workers.
self
.
_log_prompt_logprob_unsupported_warning_once
()
"""Process prompt logprobs associated with each step of a multi-step-
scheduled computation.
Args:
seq_group: the outputs are associated with this :class:`SequenceGroup`
outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
"""
for
output
in
outputs
:
# Concatenate single-step prompt logprob processing results.
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
@
staticmethod
@
functools
.
lru_cache
()
...
...
vllm/engine/output_processor/single_step.py
View file @
428dd144
...
...
@@ -15,6 +15,44 @@ from vllm.utils import Counter
logger
=
init_logger
(
__name__
)
def
single_step_process_prompt_logprob
(
sg_output_proc
:
SequenceGroupOutputProcessor
,
seq_group
:
SequenceGroup
,
output
:
SequenceGroupOutput
)
->
None
:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
Do nothing if the output has no prompt logprobs.
Account for the fact that transformers do not compute first-token logprobs.
Args:
sg_output_proc: :class:`SequenceGroupOutputProcessor` instance
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
prompt_logprobs
=
output
.
prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if
prompt_logprobs
is
not
None
:
if
not
seq_group
.
prompt_logprobs
:
prompt_logprobs
=
[
None
]
+
prompt_logprobs
seq_group
.
prompt_logprobs
=
[]
assert
hasattr
(
sg_output_proc
,
'detokenizer'
)
if
(
seq_group
.
sampling_params
.
detokenize
and
sg_output_proc
.
detokenizer
):
sg_output_proc
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
,
position_offset
=
len
(
seq_group
.
prompt_logprobs
))
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
class
SingleStepOutputProcessor
(
SequenceGroupOutputProcessor
):
"""SequenceGroupOutputProcessor which handles "output processing" logic,
which happens after the model returns generated token ids and before
...
...
@@ -60,27 +98,16 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
def
process_prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
"""Process prompt logprobs associated with one step of a single-step-
scheduled computation.
Args:
seq_group: the output is associated with this :class:`SequenceGroup`
output: the :class:`SequenceGroupOutput` for a single scheduler step
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
prompt_logprobs
=
output
.
prompt_logprobs
# If this is the first (or only) "chunk" of the prefill, we need
# to prepend None to the list of prompt logprobs. The reason for this
# is that for N prompt tokens, the Sampler will generate N-1 total
# prompt logprobs during prefill since the token at idx 0 will not
# have a logprob associated with it.
if
prompt_logprobs
is
not
None
:
if
not
seq_group
.
prompt_logprobs
:
prompt_logprobs
=
[
None
]
+
prompt_logprobs
seq_group
.
prompt_logprobs
=
[]
if
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
,
position_offset
=
len
(
seq_group
.
prompt_logprobs
))
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
,
...
...
vllm/engine/output_processor/util.py
View file @
428dd144
...
...
@@ -2,7 +2,8 @@ from typing import List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.sequence
import
PoolerOutput
,
SamplerOutput
,
SequenceGroupOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
PoolerOutput
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
...
...
vllm/engine/protocol.py
View file @
428dd144
...
...
@@ -5,11 +5,11 @@ from vllm.config import DecodingConfig, ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
...
...
vllm/executor/cpu_executor.py
View file @
428dd144
...
...
@@ -11,8 +11,9 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
GiB_bytes
,
get_distributed_init_method
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
vllm/executor/distributed_gpu_executor.py
View file @
428dd144
...
...
@@ -6,7 +6,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
logger
=
init_logger
(
__name__
)
...
...
vllm/executor/executor_base.py
View file @
428dd144
...
...
@@ -6,8 +6,9 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
class
ExecutorBase
(
ABC
):
...
...
vllm/executor/gpu_executor.py
View file @
428dd144
...
...
@@ -3,8 +3,9 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
from
vllm.worker.worker_base
import
WorkerBase
,
WorkerWrapperBase
...
...
vllm/executor/multiproc_gpu_executor.py
View file @
428dd144
...
...
@@ -14,7 +14,8 @@ from vllm.executor.gpu_executor import create_worker
from
vllm.executor.multiproc_worker_utils
import
(
ProcessWorkerWrapper
,
ResultHandler
,
WorkerMonitor
)
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.triton_utils
import
maybe_set_triton_cache_manager
from
vllm.utils
import
(
_run_task_with_lock
,
cuda_device_count_stateless
,
get_distributed_init_method
,
get_open_port
,
...
...
vllm/executor/neuron_executor.py
View file @
428dd144
...
...
@@ -3,7 +3,8 @@ from typing import List, Set, Tuple
from
vllm.executor.executor_base
import
ExecutorAsyncBase
,
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
(
get_distributed_init_method
,
get_ip
,
get_open_port
,
make_async
)
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment