Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4167252e
Unverified
Commit
4167252e
authored
Mar 03, 2025
by
Mark McLoughlin
Committed by
GitHub
Mar 03, 2025
Browse files
[V1] Refactor parallel sampling support (#13774)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
f35f8e22
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
201 additions
and
464 deletions
+201
-464
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+21
-40
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+22
-52
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+112
-69
vllm/v1/engine/parallel_sampling.py
vllm/v1/engine/parallel_sampling.py
+44
-300
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+2
-3
No files found.
vllm/v1/engine/async_llm.py
View file @
4167252e
...
@@ -25,7 +25,7 @@ from vllm.usage.usage_lib import UsageContext
...
@@ -25,7 +25,7 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.utils
import
cdiv
,
kill_process_tree
from
vllm.utils
import
cdiv
,
kill_process_tree
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
generate_parallel_sampling_async
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.metrics.loggers
import
(
LoggingStatLogger
,
PrometheusStatLogger
,
from
vllm.v1.metrics.loggers
import
(
LoggingStatLogger
,
PrometheusStatLogger
,
...
@@ -145,21 +145,26 @@ class AsyncLLM(EngineClient):
...
@@ -145,21 +145,26 @@ class AsyncLLM(EngineClient):
"""Add new request to the AsyncLLM."""
"""Add new request to the AsyncLLM."""
# 1) Create a new output queue for the request.
# 1) Create a new output queue for the request.
if
self
.
output_processor
.
is_request_active
(
request_id
):
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
queue
:
asyncio
.
Queue
[
RequestOutput
]
=
asyncio
.
Queue
()
queue
:
asyncio
.
Queue
[
RequestOutput
]
=
asyncio
.
Queue
()
# 2) Convert Input --> Request.
# 2) Fan out child requests (for n>1)
parent_req
=
ParentRequest
.
from_params
(
request_id
,
params
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
for
idx
in
range
(
n
):
if
parent_req
is
not
None
:
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
# 3) Convert Input --> Request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
arrival_time
,
lora_request
,
trace_headers
,
trace_headers
,
prompt_adapter_request
,
prompt_adapter_request
,
priority
)
priority
)
#
3
) Add the request to OutputProcessor (this process).
#
4
) Add the request to OutputProcessor (this process).
self
.
output_processor
.
add_request
(
request
,
queue
)
self
.
output_processor
.
add_request
(
request
,
parent_req
,
idx
,
queue
)
#
4
) Add the EngineCoreRequest to EngineCore (separate process).
#
5
) Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
request
)
await
self
.
engine_core
.
add_request_async
(
request
)
if
self
.
log_requests
:
if
self
.
log_requests
:
...
@@ -172,7 +177,7 @@ class AsyncLLM(EngineClient):
...
@@ -172,7 +177,7 @@ class AsyncLLM(EngineClient):
# requests we don't need to send multiple messages to core proc,
# requests we don't need to send multiple messages to core proc,
# and so we don't need multiple streams which then get
# and so we don't need multiple streams which then get
# re-multiplexed in the API server anyhow.
# re-multiplexed in the API server anyhow.
async
def
_
generate
(
async
def
generate
(
self
,
self
,
prompt
:
PromptType
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
...
@@ -243,30 +248,6 @@ class AsyncLLM(EngineClient):
...
@@ -243,30 +248,6 @@ class AsyncLLM(EngineClient):
await
self
.
abort
(
request_id
)
await
self
.
abort
(
request_id
)
raise
raise
def
generate
(
self
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
kwargs
=
dict
(
prompt
=
prompt
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
)
if
sampling_params
.
n
is
None
or
sampling_params
.
n
==
1
:
return
self
.
_generate
(
**
kwargs
)
else
:
# Special handling for parallel sampling requests
return
generate_parallel_sampling_async
(
generate
=
self
.
_generate
,
**
kwargs
)
async
def
_run_output_handler
(
self
):
async
def
_run_output_handler
(
self
):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
...
...
vllm/v1/engine/llm_engine.py
View file @
4167252e
...
@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer_group import (
...
@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer_group import (
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
SyncParallelSamplingManager
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
...
@@ -50,9 +50,6 @@ class LLMEngine:
...
@@ -50,9 +50,6 @@ class LLMEngine:
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
# Bookkeeping for parallel sampling requests
self
.
parallel_manager
=
SyncParallelSamplingManager
()
# important: init dp group before init the engine_core
# important: init dp group before init the engine_core
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_enabled
=
self
.
parallel_config
.
data_parallel_size
>
1
# noqa
self
.
dp_enabled
=
self
.
parallel_config
.
data_parallel_size
>
1
# noqa
...
@@ -120,8 +117,7 @@ class LLMEngine:
...
@@ -120,8 +117,7 @@ class LLMEngine:
multiprocess_mode
=
enable_multiprocessing
)
multiprocess_mode
=
enable_multiprocessing
)
def
get_num_unfinished_requests
(
self
)
->
int
:
def
get_num_unfinished_requests
(
self
)
->
int
:
return
self
.
parallel_manager
.
get_num_unfinished_requests
(
return
self
.
output_processor
.
get_num_unfinished_requests
()
self
.
output_processor
.
get_num_unfinished_requests
())
def
has_unfinished_requests
(
self
)
->
bool
:
def
has_unfinished_requests
(
self
)
->
bool
:
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
...
@@ -157,45 +153,22 @@ class LLMEngine:
...
@@ -157,45 +153,22 @@ class LLMEngine:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
None
:
)
->
None
:
"""Add request."""
# 1) Fan out child requests (for n>1)
kwargs
=
dict
(
request_id
=
request_id
,
parent_req
=
ParentRequest
.
from_params
(
request_id
,
params
)
prompt
=
prompt
,
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
params
=
params
,
for
idx
in
range
(
n
):
arrival_time
=
arrival_time
,
if
parent_req
is
not
None
:
lora_request
=
lora_request
,
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
# 2) Process raw inputs into the request.
priority
=
priority
)
# Handle parallel sampling requests differently.
if
params
is
None
or
isinstance
(
params
,
PoolingParams
)
or
params
.
n
==
1
:
self
.
_add_request
(
**
kwargs
)
else
:
# Special handling for parallel sampling requests
self
.
parallel_manager
.
add_request_parallel_sampling
(
add_request
=
self
.
_add_request
,
**
kwargs
)
def
_add_request
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
"""Add request, `n=1`"""
# 1) Process raw inputs into the request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
arrival_time
,
lora_request
,
trace_headers
,
trace_headers
,
prompt_adapter_request
,
prompt_adapter_request
,
priority
)
priority
)
#
2
) Make a new RequestState and queue.
#
3
) Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
request
)
self
.
output_processor
.
add_request
(
request
,
parent_req
,
idx
)
# 3) Add the request to EngineCore.
# 3) Add the request to EngineCore.
self
.
engine_core
.
add_request
(
request
)
self
.
engine_core
.
add_request
(
request
)
...
@@ -217,10 +190,7 @@ class LLMEngine:
...
@@ -217,10 +190,7 @@ class LLMEngine:
# 3) Abort any reqs that finished due to stop strings.
# 3) Abort any reqs that finished due to stop strings.
self
.
engine_core
.
abort_requests
(
processed_outputs
.
reqs_to_abort
)
self
.
engine_core
.
abort_requests
(
processed_outputs
.
reqs_to_abort
)
request_outputs
=
processed_outputs
.
request_outputs
return
processed_outputs
.
request_outputs
# 4) Process unfinished parallel sampling requests
return
self
.
parallel_manager
.
step
(
request_outputs
)
def
get_model_config
(
self
):
def
get_model_config
(
self
):
return
self
.
model_config
return
self
.
model_config
...
...
vllm/v1/engine/output_processor.py
View file @
4167252e
...
@@ -4,13 +4,14 @@ import asyncio
...
@@ -4,13 +4,14 @@ import asyncio
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreRequest
,
FinishReason
from
vllm.v1.engine
import
EngineCoreOutput
,
EngineCoreRequest
,
FinishReason
from
vllm.v1.engine.detokenizer
import
IncrementalDetokenizer
from
vllm.v1.engine.detokenizer
import
IncrementalDetokenizer
from
vllm.v1.engine.logprobs
import
LogprobsProcessor
from
vllm.v1.engine.logprobs
import
LogprobsProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.metrics.stats
import
(
IterationStats
,
LoRARequestStates
,
from
vllm.v1.metrics.stats
import
(
IterationStats
,
LoRARequestStates
,
RequestStateStats
)
RequestStateStats
)
...
@@ -27,6 +28,8 @@ class RequestState:
...
@@ -27,6 +28,8 @@ class RequestState:
def
__init__
(
def
__init__
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
parent_req
:
Optional
[
ParentRequest
],
request_index
:
int
,
lora_name
:
Optional
[
str
],
lora_name
:
Optional
[
str
],
output_kind
:
RequestOutputKind
,
output_kind
:
RequestOutputKind
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
...
@@ -38,6 +41,8 @@ class RequestState:
...
@@ -38,6 +41,8 @@ class RequestState:
log_stats
:
bool
,
log_stats
:
bool
,
):
):
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
parent_req
=
parent_req
self
.
request_index
=
request_index
self
.
lora_name
=
lora_name
self
.
lora_name
=
lora_name
self
.
output_kind
=
output_kind
self
.
output_kind
=
output_kind
self
.
prompt
=
prompt
self
.
prompt
=
prompt
...
@@ -56,11 +61,15 @@ class RequestState:
...
@@ -56,11 +61,15 @@ class RequestState:
cls
,
cls
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
request
:
EngineCoreRequest
,
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
],
request_index
:
int
,
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]],
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]],
log_stats
:
bool
,
log_stats
:
bool
,
)
->
"RequestState"
:
)
->
"RequestState"
:
return
cls
(
return
cls
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
parent_req
=
parent_req
,
request_index
=
request_index
,
lora_name
=
(
request
.
lora_request
.
name
lora_name
=
(
request
.
lora_request
.
name
if
request
.
lora_request
is
not
None
else
None
),
if
request
.
lora_request
is
not
None
else
None
),
output_kind
=
request
.
sampling_params
.
output_kind
,
output_kind
=
request
.
sampling_params
.
output_kind
,
...
@@ -79,6 +88,88 @@ class RequestState:
...
@@ -79,6 +88,88 @@ class RequestState:
log_stats
=
log_stats
,
log_stats
=
log_stats
,
)
)
def
make_request_output
(
self
,
new_token_ids
:
list
[
int
],
finish_reason
:
Optional
[
FinishReason
],
stop_reason
:
Union
[
int
,
str
,
None
],
)
->
Optional
[
RequestOutput
]:
finished
=
finish_reason
is
not
None
output_kind
=
self
.
output_kind
final_only
=
output_kind
==
RequestOutputKind
.
FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if
not
finished
and
(
self
.
is_prefilling
or
final_only
):
# Only the final output is required in FINAL_ONLY mode.
return
None
def
new_request_output
(
request_id
:
str
)
->
RequestOutput
:
return
self
.
_new_request_output
(
request_id
,
finished
)
completion_output
=
self
.
_new_completion_output
(
new_token_ids
,
finish_reason
,
stop_reason
)
if
self
.
parent_req
is
not
None
:
return
self
.
parent_req
.
make_request_output
(
final_only
,
completion_output
,
new_request_output
)
request_output
=
new_request_output
(
self
.
request_id
)
request_output
.
outputs
.
append
(
completion_output
)
return
request_output
def
_new_request_output
(
self
,
request_id
:
str
,
finished
:
bool
,
)
->
RequestOutput
:
if
self
.
output_kind
==
RequestOutputKind
.
DELTA
:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs
=
self
.
logprobs_processor
.
pop_prompt_logprobs
()
else
:
prompt_logprobs
=
self
.
logprobs_processor
.
prompt_logprobs
return
RequestOutput
(
request_id
=
request_id
,
prompt
=
self
.
prompt
,
prompt_token_ids
=
self
.
prompt_token_ids
,
prompt_logprobs
=
prompt_logprobs
,
outputs
=
[],
finished
=
finished
,
)
def
_new_completion_output
(
self
,
token_ids
:
list
[
int
],
finish_reason
:
Optional
[
FinishReason
],
stop_reason
:
Union
[
int
,
str
,
None
],
)
->
CompletionOutput
:
finished
=
finish_reason
is
not
None
delta
=
self
.
output_kind
==
RequestOutputKind
.
DELTA
# Prepare text and token_ids, based on delta mode
text
=
self
.
detokenizer
.
get_next_output_text
(
finished
,
delta
)
if
not
delta
:
token_ids
=
self
.
detokenizer
.
output_token_ids
# Prepare logprobs, based on delta mode
logprobs
=
self
.
logprobs_processor
.
logprobs
if
delta
and
logprobs
:
logprobs
=
logprobs
[
-
len
(
token_ids
):]
return
CompletionOutput
(
index
=
self
.
request_index
,
text
=
text
,
token_ids
=
token_ids
,
logprobs
=
logprobs
,
cumulative_logprob
=
self
.
logprobs_processor
.
cumulative_logprob
,
finish_reason
=
str
(
finish_reason
)
if
finished
else
None
,
stop_reason
=
stop_reason
if
finished
else
None
)
class
OutputProcessor
:
class
OutputProcessor
:
"""Process EngineCoreOutputs into RequestOutputs."""
"""Process EngineCoreOutputs into RequestOutputs."""
...
@@ -93,9 +184,6 @@ class OutputProcessor:
...
@@ -93,9 +184,6 @@ class OutputProcessor:
self
.
request_states
:
dict
[
str
,
RequestState
]
=
{}
self
.
request_states
:
dict
[
str
,
RequestState
]
=
{}
self
.
lora_states
=
LoRARequestStates
()
self
.
lora_states
=
LoRARequestStates
()
def
is_request_active
(
self
,
request_id
:
str
)
->
bool
:
return
request_id
in
self
.
request_states
def
get_num_unfinished_requests
(
self
):
def
get_num_unfinished_requests
(
self
):
return
len
(
self
.
request_states
)
return
len
(
self
.
request_states
)
...
@@ -114,6 +202,8 @@ class OutputProcessor:
...
@@ -114,6 +202,8 @@ class OutputProcessor:
def
add_request
(
def
add_request
(
self
,
self
,
request
:
EngineCoreRequest
,
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
]
=
None
,
request_index
:
int
=
0
,
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]]
=
None
,
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]]
=
None
,
)
->
None
:
)
->
None
:
request_id
=
request
.
request_id
request_id
=
request
.
request_id
...
@@ -123,6 +213,8 @@ class OutputProcessor:
...
@@ -123,6 +213,8 @@ class OutputProcessor:
req_state
=
RequestState
.
from_new_request
(
req_state
=
RequestState
.
from_new_request
(
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
request
.
lora_request
),
tokenizer
=
self
.
tokenizer
.
get_lora_tokenizer
(
request
.
lora_request
),
request
=
request
,
request
=
request
,
parent_req
=
parent_req
,
request_index
=
request_index
,
queue
=
queue
,
queue
=
queue
,
log_stats
=
self
.
log_stats
)
log_stats
=
self
.
log_stats
)
self
.
request_states
[
request_id
]
=
req_state
self
.
request_states
[
request_id
]
=
req_state
...
@@ -202,8 +294,8 @@ class OutputProcessor:
...
@@ -202,8 +294,8 @@ class OutputProcessor:
req_state
.
logprobs_processor
.
update_from_output
(
engine_core_output
)
req_state
.
logprobs_processor
.
update_from_output
(
engine_core_output
)
# 4) Create and handle RequestOutput objects.
# 4) Create and handle RequestOutput objects.
if
request_output
:
=
self
.
_
make_request_output
(
if
request_output
:
=
req_state
.
make_request_output
(
req_state
,
new_token_ids
,
finish_reason
,
stop_reason
):
new_token_ids
,
finish_reason
,
stop_reason
):
if
req_state
.
queue
is
not
None
:
if
req_state
.
queue
is
not
None
:
# AsyncLLM: put into queue for handling by generate().
# AsyncLLM: put into queue for handling by generate().
req_state
.
queue
.
put_nowait
(
request_output
)
req_state
.
queue
.
put_nowait
(
request_output
)
...
@@ -212,7 +304,7 @@ class OutputProcessor:
...
@@ -212,7 +304,7 @@ class OutputProcessor:
request_outputs
.
append
(
request_output
)
request_outputs
.
append
(
request_output
)
# Free completed requests.
# Free completed requests.
if
request_output
.
finished
:
if
finish_reason
is
not
None
:
self
.
request_states
.
pop
(
req_id
)
self
.
request_states
.
pop
(
req_id
)
if
not
engine_core_output
.
finished
:
if
not
engine_core_output
.
finished
:
# If req not finished in EngineCore, but Detokenizer
# If req not finished in EngineCore, but Detokenizer
...
@@ -220,8 +312,7 @@ class OutputProcessor:
...
@@ -220,8 +312,7 @@ class OutputProcessor:
reqs_to_abort
.
append
(
req_id
)
reqs_to_abort
.
append
(
req_id
)
# Track per-request stats
# Track per-request stats
self
.
_update_stats_from_finished
(
req_state
,
request_output
,
self
.
_update_stats_from_finished
(
req_state
,
finish_reason
,
finish_reason
,
iteration_stats
)
iteration_stats
)
self
.
lora_states
.
update_iteration_stats
(
iteration_stats
)
self
.
lora_states
.
update_iteration_stats
(
iteration_stats
)
...
@@ -249,7 +340,6 @@ class OutputProcessor:
...
@@ -249,7 +340,6 @@ class OutputProcessor:
req_state
.
stats
,
lora_stats
)
req_state
.
stats
,
lora_stats
)
def
_update_stats_from_finished
(
self
,
req_state
:
RequestState
,
def
_update_stats_from_finished
(
self
,
req_state
:
RequestState
,
request_output
:
RequestOutput
,
finish_reason
:
Optional
[
FinishReason
],
finish_reason
:
Optional
[
FinishReason
],
iteration_stats
:
Optional
[
IterationStats
]):
iteration_stats
:
Optional
[
IterationStats
]):
if
iteration_stats
is
None
:
if
iteration_stats
is
None
:
...
@@ -257,55 +347,8 @@ class OutputProcessor:
...
@@ -257,55 +347,8 @@ class OutputProcessor:
assert
finish_reason
is
not
None
assert
finish_reason
is
not
None
assert
req_state
.
stats
is
not
None
assert
req_state
.
stats
is
not
None
iteration_stats
.
update_from_finished_request
(
finish_reason
,
iteration_stats
.
update_from_finished_request
(
request_output
,
finish_reason
=
finish_reason
,
req_state
.
stats
)
num_prompt_tokens
=
len
(
req_state
.
prompt_token_ids
),
req_stats
=
req_state
.
stats
)
self
.
lora_states
.
finish_request
(
req_state
)
self
.
lora_states
.
finish_request
(
req_state
)
@
staticmethod
def
_make_request_output
(
request_state
:
RequestState
,
new_token_ids
:
list
[
int
],
finish_reason
:
Optional
[
FinishReason
],
stop_reason
:
Union
[
int
,
str
,
None
],
)
->
Optional
[
RequestOutput
]:
finished
=
finish_reason
is
not
None
output_kind
=
request_state
.
output_kind
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if
not
finished
and
(
request_state
.
is_prefilling
or
output_kind
==
RequestOutputKind
.
FINAL_ONLY
):
# Only the final output is required in FINAL_ONLY mode.
return
None
detokenizer
=
request_state
.
detokenizer
logprobs_processor
=
request_state
.
logprobs_processor
delta
=
output_kind
==
RequestOutputKind
.
DELTA
logprobs
=
logprobs_processor
.
logprobs
if
delta
:
if
logprobs
:
logprobs
=
logprobs
[
-
len
(
new_token_ids
):]
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs
=
logprobs_processor
.
pop_prompt_logprobs
()
else
:
prompt_logprobs
=
logprobs_processor
.
prompt_logprobs
request_output
=
RequestOutput
.
new
(
request_id
=
request_state
.
request_id
,
prompt
=
request_state
.
prompt
,
prompt_token_ids
=
request_state
.
prompt_token_ids
,
text
=
detokenizer
.
get_next_output_text
(
finished
,
delta
),
token_ids
=
new_token_ids
if
delta
else
detokenizer
.
output_token_ids
,
logprobs
=
logprobs
,
prompt_logprobs
=
prompt_logprobs
,
cumulative_logprob
=
logprobs_processor
.
cumulative_logprob
,
finished
=
finished
,
)
if
finished
:
completion_output
=
request_output
.
outputs
[
0
]
completion_output
.
finish_reason
=
str
(
finish_reason
)
completion_output
.
stop_reason
=
stop_reason
return
request_output
vllm/v1/engine/parallel_sampling.py
View file @
4167252e
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
AsyncGenerator
,
Mapping
from
copy
import
copy
from
copy
import
copy
from
typing
import
Optional
,
Protoco
l
,
Union
from
typing
import
Callable
,
Optiona
l
,
Union
from
vllm.inputs
import
PromptType
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.utils
import
merge_async_iterators
class
AsyncGenerateMethodType
(
Protocol
):
class
ParentRequest
:
def
__call__
(
self
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
class
SyncAddRequestMethodType
(
Protocol
):
def
__call__
(
self
,
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
)
->
None
:
...
class
ParallelSamplingRequest
:
"""Info, state & processing for parallel sampling request.
"""Info, state & processing for parallel sampling request.
Store parent request ID and sampling params.
Store parent request ID and sampling params.
Facilitate generating child request sampling params.
Facilitate generating child request sampling params.
Transform child request outputs into parent request
outputs.
When stream mode is disabled, then `self.request_output`
aggregates child request completions.
"""
"""
request_id
:
str
request_id
:
str
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
# To aggregate child completions when not streaming
output_aggregator
:
Optional
[
RequestOutput
]
# To efficiently obtain child sampling params
cached_child_sampling_params
:
Optional
[
SamplingParams
]
cached_child_sampling_params
:
Optional
[
SamplingParams
]
request_output
:
Optional
[
RequestOutput
]
num_finished_completions
:
int
def
__init__
(
self
,
request_id
:
str
,
def
__init__
(
self
,
request_id
:
str
,
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
self
.
output_aggregator
=
None
self
.
cached_child_sampling_params
=
None
self
.
cached_child_sampling_params
=
None
self
.
request_output
=
None
self
.
num_finished_completions
=
0
@
classmethod
def
from_params
(
cls
,
request_id
:
str
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
Optional
[
'ParentRequest'
]:
if
not
isinstance
(
params
,
SamplingParams
)
or
params
.
n
==
1
:
return
None
return
cls
(
request_id
,
params
)
def
_get_child_sampling_params
(
def
_get_child_sampling_params
(
self
,
self
,
...
@@ -96,47 +73,6 @@ class ParallelSamplingRequest:
...
@@ -96,47 +73,6 @@ class ParallelSamplingRequest:
child_sampling_params
.
seed
=
seed
+
index
child_sampling_params
.
seed
=
seed
+
index
return
child_sampling_params
return
child_sampling_params
def
_add_output
(
self
,
child_req_output
:
RequestOutput
,
index
:
int
,
)
->
None
:
"""Aggregate a parallel sampling child
request output.
Non-stream-mode (`output_kind == FINAL_ONLY`)
only. Inject correct parent request ID and
completion index.
Args:
child_req_output: a single request output
from a parallel sampling
child request.
index: index within `n` child
"""
self
.
num_finished_completions
+=
1
new_completion
=
child_req_output
.
outputs
[
0
]
new_completion
.
index
=
index
if
self
.
request_output
is
None
:
# Save the first request output; reinstate
# original request ID; metrics are not
# supported for parallel sampling
child_req_output
.
request_id
=
self
.
request_id
child_req_output
.
metrics
=
None
self
.
request_output
=
child_req_output
else
:
# Aggregate additional completion into request output
# Note: will be sorted by index later
self
.
request_output
.
outputs
.
append
(
new_completion
)
def
_get_final_request_output
(
self
)
->
RequestOutput
:
"""Invariant: parent completion outputs sorted by index"""
assert
self
.
request_output
is
not
None
self
.
request_output
.
finished
=
True
self
.
request_output
.
outputs
=
sorted
(
self
.
request_output
.
outputs
,
key
=
lambda
x
:
x
.
index
)
return
self
.
request_output
def
get_child_info
(
self
,
index
:
int
)
->
tuple
[
str
,
SamplingParams
]:
def
get_child_info
(
self
,
index
:
int
)
->
tuple
[
str
,
SamplingParams
]:
"""Get child request ID and sampling params.
"""Get child request ID and sampling params.
...
@@ -149,227 +85,35 @@ class ParallelSamplingRequest:
...
@@ -149,227 +85,35 @@ class ParallelSamplingRequest:
return
(
f
"
{
index
}
_
{
self
.
request_id
}
"
,
return
(
f
"
{
index
}
_
{
self
.
request_id
}
"
,
self
.
_get_child_sampling_params
(
index
))
self
.
_get_child_sampling_params
(
index
))
def
process_output
(
self
,
child_req_output
:
RequestOutput
,
index
:
int
,
)
->
Optional
[
RequestOutput
]:
"""Filter, aggregate and transform parallel sampling
child request outputs.
If the parent request has `stream=false`
(`output_kind == FINAL_ONLY`), each child will also have
`output_kind == FINAL_ONLY`. All child request outputs
must be aggregated into a single request output, with
multiple completions. This request output is only returned
once `n` completions are aggregated.
If the parent request has `stream=true`
(`output_kind == DELTA`), each child will also have
`output_kind == DELTA`. All child request outputs
must be streamed directly to the caller.
Args:
child_req_output: a single child request output
index: index within `n` child requests
Returns:
`None`, unless a processed request output is ready to
send back to the caller.
"""
if
self
.
output_kind
!=
RequestOutputKind
.
FINAL_ONLY
:
# stream=true: return child completions immediately
child_req_output
.
request_id
=
self
.
request_id
child_req_output
.
outputs
[
0
].
index
=
index
if
child_req_output
.
finished
:
# Parent request is complete if all child requests are
# complete.
self
.
num_finished_completions
+=
1
child_req_output
.
finished
=
(
self
.
num_finished_completions
==
self
.
n
)
return
child_req_output
# stream=false: aggregate child completions
self
.
_add_output
(
child_req_output
,
index
)
if
self
.
num_finished_completions
==
self
.
n
:
# Return aggregated request output after obtaining
# all completions
return
self
.
_get_final_request_output
()
return
None
async
def
wrap_child_async_generator
(
self
,
child_gen
:
AsyncGenerator
[
RequestOutput
,
None
],
index
:
int
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Output generator for a single parallel sampling
child request.
Each parallel sampling request triggers at
least two child requests. This generator
yields zero or more request outputs to
return to the caller, as they become
available.
Args:
child_gen: generator for child request
outputs.
index: index within the `n` child requests
Returns:
Yields zero or more request outputs to return
to the caller.
"""
async
for
out
in
child_gen
:
if
req_out
:
=
self
.
process_output
(
out
,
index
):
yield
req_out
@
property
@
property
def
n
(
self
)
->
int
:
def
n
(
self
)
->
int
:
return
self
.
sampling_params
.
n
return
self
.
sampling_params
.
n
@
property
def
make_request_output
(
def
output_kind
(
self
)
->
RequestOutputKind
:
return
self
.
sampling_params
.
output_kind
class
SyncParallelSamplingManager
:
def
__init__
(
self
):
# Parent req ID -> parent request manager
self
.
parent_reqs
:
dict
[
str
,
ParallelSamplingRequest
]
=
{}
# Child req ID -> (child req index, parent req ID)
self
.
child_reqs
:
dict
[
str
,
tuple
[
int
,
str
]]
=
{}
def
_register_parent_request
(
self
,
req
:
ParallelSamplingRequest
)
->
None
:
"""Register parallel sampling parent request."""
self
.
parent_reqs
[
req
.
request_id
]
=
req
def
_register_child_request
(
self
,
req_id
:
str
,
child_req_id
:
str
,
index
:
int
)
->
None
:
"""Register parallel sampling child request with parent.
Args:
req_id: parent request ID
child_req_id: child request ID
index: child request index within `n` child requests
"""
self
.
child_reqs
[
child_req_id
]
=
(
index
,
req_id
)
def
get_num_unfinished_requests
(
self
,
num_core_reqs
:
int
)
->
int
:
"""Get the number of unfinished requests, correcting for parallel
sampling.
Args:
num_core_reqs: The number of unfinished requests in the engine core.
Returns:
Number of unfinished requests, where each parallel sampling req
counts as 1
"""
return
num_core_reqs
+
len
(
self
.
parent_reqs
)
-
len
(
self
.
child_reqs
)
def
add_request_parallel_sampling
(
self
,
self
,
add_request
:
SyncAddRequestMethodType
,
final_only
:
bool
,
request_id
:
str
,
completion_output
:
CompletionOutput
,
prompt
:
PromptType
,
new_request_output
:
Callable
[[
str
],
RequestOutput
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
Optional
[
RequestOutput
]:
arrival_time
:
Optional
[
float
]
=
None
,
# Use an existing RequestOutput if we're aggregating
lora_request
:
Optional
[
LoRARequest
]
=
None
,
request_output
=
self
.
output_aggregator
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
"""Add sync parallel sampling request."""
req
=
ParallelSamplingRequest
(
request_id
,
params
)
self
.
_register_parent_request
(
req
)
# Add n child requests with unique request IDs & random seeds and n=1
for
idx
in
range
(
req
.
n
):
child_req_id
,
child_params
=
req
.
get_child_info
(
idx
)
self
.
_register_child_request
(
request_id
,
child_req_id
,
idx
)
add_request
(
request_id
=
child_req_id
,
prompt
=
prompt
,
params
=
child_params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
)
# type: ignore
def
step
(
self
,
outputs
:
list
[
RequestOutput
],
)
->
list
[
RequestOutput
]:
"""Build parallel sampling request outputs.
Extract child request outputs, aggregate them
into parent request output, and return parent
output when complete.
Do not modify `n=1` requests.
Args:
# Make new RequestOutput otherwise
outputs: step
request
output
s. Mix of child request
if
request
_
output
is
None
:
outputs & `n=1`
request
output
s.
request_output
=
new_
request
_
output
(
self
.
request_id
)
Return:
# Add a new completion
List of parallel sampling parent request outputs &
request_output
.
outputs
.
append
(
completion_output
)
unmodified `n=1` request outputs passed-thru from input.
"""
if
not
(
self
.
parent_reqs
and
outputs
):
# Return unmodified
return
outputs
agg_outputs
=
[]
for
output
in
outputs
:
req_id
=
output
.
request_id
if
child_req_entry
:
=
self
.
child_reqs
.
get
(
req_id
,
None
):
# For each parallel sampling child request output:
(
index
,
parent_req_id
)
=
child_req_entry
req
=
self
.
parent_reqs
[
parent_req_id
]
# Update parallel sampling request
if
out
:
=
req
.
process_output
(
output
,
index
):
# Return parent request output if complete;
# cleanup parent request bookkeeping.
agg_outputs
.
append
(
out
)
del
self
.
parent_reqs
[
parent_req_id
]
# Cleanup child request bookkeeping.
del
self
.
child_reqs
[
req_id
]
else
:
# Not a parallel sampling request output
agg_outputs
.
append
(
output
)
return
agg_outputs
# If not streaming, aggregate until all child requests complete
if
final_only
and
len
(
request_output
.
outputs
)
!=
self
.
n
:
self
.
output_aggregator
=
request_output
return
None
async
def
generate_parallel_sampling_async
(
# We're done aggregating
generate
:
AsyncGenerateMethodType
,
self
.
output_aggregator
=
None
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate completions for async parallel sampling requests."""
parent_req
=
ParallelSamplingRequest
(
request_id
,
sampling_params
)
# Aggregate generators for n child requests
gens
:
list
[
AsyncGenerator
[
RequestOutput
,
None
]]
=
[]
for
idx
in
range
(
parent_req
.
n
):
child_req_id
,
child_params
=
parent_req
.
get_child_info
(
idx
)
child_gen
=
generate
(
prompt
=
prompt
,
sampling_params
=
child_params
,
request_id
=
child_req_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
# type: ignore
gen
=
parent_req
.
wrap_child_async_generator
(
child_gen
,
idx
)
gens
.
append
(
gen
)
# Merge generators
# Parent completion output list must be sorted by index
async
for
_
,
out
in
merge_async_iterators
(
*
gens
):
request_output
.
outputs
=
sorted
(
request_output
.
outputs
,
yield
out
key
=
lambda
x
:
x
.
index
)
return
request_output
vllm/v1/metrics/stats.py
View file @
4167252e
...
@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
...
@@ -5,7 +5,6 @@ from dataclasses import dataclass, field
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.outputs
import
RequestOutput
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.engine
import
EngineCoreEvent
,
EngineCoreOutput
,
FinishReason
from
vllm.v1.output_processor
import
RequestState
from
vllm.v1.output_processor
import
RequestState
...
@@ -150,7 +149,7 @@ class IterationStats:
...
@@ -150,7 +149,7 @@ class IterationStats:
self
.
num_preempted_reqs
+=
1
self
.
num_preempted_reqs
+=
1
def
update_from_finished_request
(
self
,
finish_reason
:
"FinishReason"
,
def
update_from_finished_request
(
self
,
finish_reason
:
"FinishReason"
,
request_output
:
"RequestOutput"
,
num_prompt_tokens
:
int
,
req_stats
:
RequestStateStats
):
req_stats
:
RequestStateStats
):
e2e_latency
=
self
.
_time_since
(
req_stats
.
arrival_time
)
e2e_latency
=
self
.
_time_since
(
req_stats
.
arrival_time
)
...
@@ -172,7 +171,7 @@ class IterationStats:
...
@@ -172,7 +171,7 @@ class IterationStats:
finished_req
=
\
finished_req
=
\
FinishedRequestStats
(
finish_reason
=
finish_reason
,
FinishedRequestStats
(
finish_reason
=
finish_reason
,
e2e_latency
=
e2e_latency
,
e2e_latency
=
e2e_latency
,
num_prompt_tokens
=
len
(
request_output
.
prompt_token
_ids
)
,
num_prompt_tokens
=
num_
prompt_token
s
,
num_generation_tokens
=
req_stats
.
num_generation_tokens
,
num_generation_tokens
=
req_stats
.
num_generation_tokens
,
queued_time
=
queued_time
,
queued_time
=
queued_time
,
prefill_time
=
prefill_time
,
prefill_time
=
prefill_time
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment