Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f5d3acd4
Unverified
Commit
f5d3acd4
authored
Mar 12, 2025
by
Nick Hill
Committed by
GitHub
Mar 12, 2025
Browse files
[BugFix][V1] Fix parallel sampling finishing/aborts (#14512)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
916836bb
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
137 additions
and
113 deletions
+137
-113
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+50
-6
tests/v1/entrypoints/openai/test_completion.py
tests/v1/entrypoints/openai/test_completion.py
+15
-6
vllm/outputs.py
vllm/outputs.py
+18
-46
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+1
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-1
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+30
-19
vllm/v1/engine/parallel_sampling.py
vllm/v1/engine/parallel_sampling.py
+22
-33
No files found.
tests/v1/engine/test_async_llm.py
View file @
f5d3acd4
...
@@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM,
...
@@ -46,6 +46,7 @@ async def generate(engine: AsyncLLM,
prompt
:
PromptType
,
prompt
:
PromptType
,
output_kind
:
RequestOutputKind
,
output_kind
:
RequestOutputKind
,
max_tokens
:
int
,
max_tokens
:
int
,
n
:
int
=
1
,
prompt_logprobs
:
Optional
[
int
]
=
None
)
->
tuple
[
int
,
str
]:
prompt_logprobs
:
Optional
[
int
]
=
None
)
->
tuple
[
int
,
str
]:
# Ensure generate doesn't complete too fast for cancellation test.
# Ensure generate doesn't complete too fast for cancellation test.
await
asyncio
.
sleep
(
0.2
)
await
asyncio
.
sleep
(
0.2
)
...
@@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM,
...
@@ -54,13 +55,15 @@ async def generate(engine: AsyncLLM,
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
sampling_params
=
SamplingParams
(
max_tokens
=
max_tokens
,
ignore_eos
=
True
,
ignore_eos
=
True
,
output_kind
=
output_kind
,
output_kind
=
output_kind
,
temperature
=
0
,
temperature
=
0.5
,
seed
=
33
,
n
=
n
,
prompt_logprobs
=
prompt_logprobs
)
prompt_logprobs
=
prompt_logprobs
)
async
for
out
in
engine
.
generate
(
request_id
=
request_id
,
async
for
out
in
engine
.
generate
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
,
sampling_params
=
sampling_params
):
sampling_params
=
sampling_params
):
num_tokens
=
len
(
out
.
outputs
[
0
].
token_id
s
)
num_tokens
=
sum
(
len
(
out
put
.
token_ids
)
for
output
in
out
.
output
s
)
if
output_kind
==
RequestOutputKind
.
DELTA
:
if
output_kind
==
RequestOutputKind
.
DELTA
:
count
+=
num_tokens
count
+=
num_tokens
else
:
else
:
...
@@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
...
@@ -136,17 +139,22 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
NUM_REQUESTS
=
100
NUM_REQUESTS
=
100
NUM_EXPECTED_TOKENS
=
100
NUM_EXPECTED_TOKENS
=
100
NUM_EXPECTED_TOKENS_LONG
=
50000
REQUEST_IDS_TO_ABORT
=
range
(
1
,
100
,
10
)
REQUEST_IDS_TO_ABORT
=
range
(
1
,
100
,
10
)
PARALLEL_SAMPLE_REQ_IDS
=
range
(
1
,
100
,
15
)
request_ids
=
[
f
"request-
{
i
}
"
for
i
in
range
(
NUM_REQUESTS
)]
request_ids
=
[
f
"request-
{
i
}
"
for
i
in
range
(
NUM_REQUESTS
)]
# Create concurrent requests.
# Create concurrent requests.
tasks
:
list
[
asyncio
.
Task
]
=
[]
tasks
:
list
[
asyncio
.
Task
]
=
[]
for
request_id
in
request_ids
:
for
idx
,
request_id
in
enumerate
(
request_ids
):
max_tokens
=
NUM_EXPECTED_TOKENS_LONG
if
(
idx
in
REQUEST_IDS_TO_ABORT
)
else
NUM_EXPECTED_TOKENS
n
=
3
if
idx
in
PARALLEL_SAMPLE_REQ_IDS
else
1
tasks
.
append
(
tasks
.
append
(
asyncio
.
create_task
(
asyncio
.
create_task
(
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
generate
(
engine
,
request_id
,
prompt
,
output_kind
,
NUM_EXPECTED_TOKENS
)))
max_tokens
,
n
)))
# API server cancels requests when they disconnect.
# API server cancels requests when they disconnect.
for
idx
in
REQUEST_IDS_TO_ABORT
:
for
idx
in
REQUEST_IDS_TO_ABORT
:
...
@@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
...
@@ -162,10 +170,13 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
else
:
else
:
# Otherwise, make sure the request was not impacted.
# Otherwise, make sure the request was not impacted.
num_generated_tokens
,
request_id
=
await
task
num_generated_tokens
,
request_id
=
await
task
assert
num_generated_tokens
==
NUM_EXPECTED_TOKENS
,
(
n
=
3
if
idx
in
PARALLEL_SAMPLE_REQ_IDS
else
1
expected_tokens
=
NUM_EXPECTED_TOKENS
*
n
assert
num_generated_tokens
==
expected_tokens
,
(
f
"
{
request_id
}
generated
{
num_generated_tokens
}
but "
f
"
{
request_id
}
generated
{
num_generated_tokens
}
but "
f
"expected
{
NUM_EXPECTED_TOKENS
}
"
)
f
"expected
{
expected_tokens
}
"
)
# Make sure all aborted requests were really aborted.
assert
not
engine
.
output_processor
.
has_unfinished_requests
()
assert
not
engine
.
output_processor
.
has_unfinished_requests
()
# Confirm we can do another generation.
# Confirm we can do another generation.
...
@@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
...
@@ -176,3 +187,36 @@ async def test_abort(monkeypatch, output_kind: RequestOutputKind,
num_generated_tokens
,
request_id
=
await
task
num_generated_tokens
,
request_id
=
await
task
assert
num_generated_tokens
==
NUM_EXPECTED_TOKENS
assert
num_generated_tokens
==
NUM_EXPECTED_TOKENS
assert
not
engine
.
output_processor
.
has_unfinished_requests
()
assert
not
engine
.
output_processor
.
has_unfinished_requests
()
@
pytest
.
mark
.
parametrize
(
"n"
,
[
1
,
3
])
@
pytest
.
mark
.
parametrize
(
"engine_args_and_prompt"
,
[(
TEXT_ENGINE_ARGS
,
TEXT_PROMPT
),
(
VISION_ENGINE_ARGS
,
VISION_PROMPT
)])
@
pytest
.
mark
.
asyncio
async
def
test_finished_flag
(
monkeypatch
,
n
:
int
,
engine_args_and_prompt
:
tuple
[
AsyncEngineArgs
,
PromptType
]):
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
engine_args
,
prompt
=
engine_args_and_prompt
engine
=
AsyncLLM
.
from_engine_args
(
engine_args
)
after
.
callback
(
engine
.
shutdown
)
sampling_params
=
SamplingParams
(
max_tokens
=
100
,
output_kind
=
RequestOutputKind
.
DELTA
,
temperature
=
1.0
,
seed
=
33
,
n
=
n
)
outputs
=
[
out
async
for
out
in
engine
.
generate
(
request_id
=
"request-33"
,
prompt
=
prompt
,
sampling_params
=
sampling_params
)
]
# Assert only the last output has the finished flag set
assert
all
(
not
out
.
finished
for
out
in
outputs
[:
-
1
])
assert
outputs
[
-
1
].
finished
tests/v1/entrypoints/openai/test_completion.py
View file @
f5d3acd4
...
@@ -263,15 +263,16 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
...
@@ -263,15 +263,16 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
prompt
=
"What is an LLM?"
prompt
=
"What is an LLM?"
n
=
3
n
=
3
max_tokens
=
5
max_tokens
=
5
0
# we want some to finish earlier than others
# High temperature to maximize chance of unique completions.
# High temperature to maximize chance of unique completions.
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
completion
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
prompt
=
prompt
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
n
=
n
,
n
=
n
,
temperature
=
0.95
,
temperature
=
1.0
,
stream
=
False
,
stream
=
False
,
logprobs
=
0
,
seed
=
42
)
seed
=
42
)
# Assert `n` completions
# Assert `n` completions
...
@@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
...
@@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
assert
num_completions
==
n
,
(
assert
num_completions
==
n
,
(
f
"Num completions
{
num_completions
}
but expected
{
n
}
."
)
f
"Num completions
{
num_completions
}
but expected
{
n
}
."
)
completion_repeats
:
dict
[
str
,
int
]
=
{}
completion_repeats
:
dict
[
str
,
int
]
=
{}
output_token_lengths
=
set
()
for
idx
,
choice
in
enumerate
(
completion
.
choices
):
for
idx
,
choice
in
enumerate
(
completion
.
choices
):
# Assert correct completion index & some finish reason.
# Assert correct completion index & some finish reason.
assert
choice
.
index
==
idx
,
(
assert
choice
.
index
==
idx
,
(
...
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
...
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
"None finish_reason is invalid."
)
"None finish_reason is invalid."
)
text
=
choice
.
text
text
=
choice
.
text
completion_repeats
[
text
]
=
completion_repeats
.
get
(
text
,
0
)
+
1
completion_repeats
[
text
]
=
completion_repeats
.
get
(
text
,
0
)
+
1
output_token_lengths
.
add
(
len
(
choice
.
logprobs
.
tokens
))
# Assert subrequests finished at different times
assert
len
(
output_token_lengths
)
>
1
# Assert `n` unique completions
# Assert `n` unique completions
num_unique
=
len
(
completion_repeats
)
num_unique
=
len
(
completion_repeats
)
if
num_unique
!=
n
:
if
num_unique
!=
n
:
...
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
...
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
prompt
=
"What is an LLM?"
prompt
=
"What is an LLM?"
n
=
3
n
=
3
max_tokens
=
5
max_tokens
=
5
0
# we want some to finish earlier than others
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
prompt
=
prompt
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
n
=
n
,
n
=
n
,
temperature
=
0.95
,
temperature
=
1.0
,
stream
=
True
,
stream
=
True
,
seed
=
42
)
seed
=
42
)
chunks
:
list
[
list
[
str
]]
=
[[]
for
i
in
range
(
n
)]
chunks
:
list
[
list
[
str
]]
=
[[]
for
_
in
range
(
n
)]
finish_reason_count
=
0
finish_reason_count
=
0
async
for
chunk
in
stream
:
async
for
chunk
in
stream
:
index
=
chunk
.
choices
[
0
].
index
index
=
chunk
.
choices
[
0
].
index
...
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
...
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
assert
finish_reason_count
==
n
,
(
assert
finish_reason_count
==
n
,
(
f
"Expected
{
n
}
completions with valid indices and finish_reason."
)
f
"Expected
{
n
}
completions with valid indices and finish_reason."
)
completion_repeats
:
dict
[
str
,
int
]
=
{}
completion_repeats
:
dict
[
str
,
int
]
=
{}
chunk_lengths
=
set
()
for
chunk
in
chunks
:
for
chunk
in
chunks
:
chunk_len
=
len
(
chunk
)
chunk_len
=
len
(
chunk
)
# Assert correct number of completion tokens
# Assert correct number of completion tokens
assert
chunk_len
==
max_tokens
,
(
chunk_lengths
.
add
(
chunk_len
)
assert
chunk_len
<=
max_tokens
,
(
f
"max_tokens=
{
max_tokens
}
but chunk len is
{
chunk_len
}
."
)
f
"max_tokens=
{
max_tokens
}
but chunk len is
{
chunk_len
}
."
)
text
=
""
.
join
(
chunk
)
text
=
""
.
join
(
chunk
)
completion_repeats
[
text
]
=
completion_repeats
.
get
(
text
,
0
)
+
1
completion_repeats
[
text
]
=
completion_repeats
.
get
(
text
,
0
)
+
1
print
(
text
)
print
(
text
)
# Assert subrequests finished at different times
assert
len
(
chunk_lengths
)
>
1
# Assert `n` unique completions
# Assert `n` unique completions
num_unique
=
len
(
completion_repeats
)
num_unique
=
len
(
completion_repeats
)
if
num_unique
!=
n
:
if
num_unique
!=
n
:
...
...
vllm/outputs.py
View file @
f5d3acd4
...
@@ -134,57 +134,29 @@ class RequestOutput:
...
@@ -134,57 +134,29 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
num_cached_tokens
=
num_cached_tokens
self
.
num_cached_tokens
=
num_cached_tokens
@
classmethod
def
new
(
cls
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
list
[
int
]],
text
:
str
,
token_ids
:
list
[
int
],
logprobs
:
Optional
[
SampleLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
cumulative_logprob
:
Optional
[
float
],
finished
:
bool
=
False
,
)
->
"RequestOutput"
:
"""Initialize a new RequestOutput object."""
# TODO: Support `n` > 1.
completion_output
=
CompletionOutput
(
index
=
0
,
text
=
text
,
token_ids
=
token_ids
,
cumulative_logprob
=
cumulative_logprob
,
logprobs
=
logprobs
)
return
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
prompt_logprobs
,
outputs
=
[
completion_output
],
finished
=
finished
,
)
def
add
(
self
,
next_output
:
"RequestOutput"
)
->
None
:
def
add
(
self
,
next_output
:
"RequestOutput"
)
->
None
:
"""Merge subsequent RequestOutput into this one"""
"""Merge subsequent RequestOutput into this one"""
self
.
prompt
=
next_output
.
prompt
self
.
prompt_token_ids
=
next_output
.
prompt_token_ids
self
.
prompt_logprobs
=
next_output
.
prompt_logprobs
self
.
finished
|=
next_output
.
finished
self
.
finished
|=
next_output
.
finished
#TODO assuming n == 1 for now
for
next_completion
in
next_output
.
outputs
:
completion
=
self
.
outputs
[
0
]
for
completion
in
self
.
outputs
:
next_completion
=
next_output
.
outputs
[
0
]
if
completion
.
index
==
next_completion
.
index
:
completion
.
text
+=
next_completion
.
text
# Merge outputs with same index
if
not
isinstance
(
completion
.
token_ids
,
MutableSequence
):
completion
.
text
+=
next_completion
.
text
completion
.
token_ids
=
list
(
completion
.
token_ids
)
if
not
isinstance
(
completion
.
token_ids
,
MutableSequence
):
completion
.
token_ids
.
extend
(
next_completion
.
token_ids
)
completion
.
token_ids
=
list
(
completion
.
token_ids
)
if
next_completion
.
logprobs
:
completion
.
token_ids
.
extend
(
next_completion
.
token_ids
)
assert
completion
.
logprobs
is
not
None
if
next_completion
.
logprobs
:
completion
.
logprobs
.
extend
(
next_completion
.
logprobs
)
assert
completion
.
logprobs
is
not
None
completion
.
cumulative_logprob
=
next_completion
.
cumulative_logprob
completion
.
logprobs
.
extend
(
next_completion
.
logprobs
)
completion
.
cumulative_logprob
=
(
next_completion
.
cumulative_logprob
)
completion
.
finish_reason
=
next_completion
.
finish_reason
completion
.
stop_reason
=
next_completion
.
stop_reason
break
else
:
self
.
outputs
.
append
(
next_completion
)
@
classmethod
@
classmethod
def
from_seq_group
(
def
from_seq_group
(
...
...
vllm/v1/engine/async_llm.py
View file @
f5d3acd4
...
@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
...
@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort RequestId in OutputProcessor and EngineCore."""
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids
=
[
request_id
]
request_ids
=
self
.
output_processor
.
abort_requests
((
request_id
,
))
await
self
.
engine_core
.
abort_requests_async
(
request_ids
)
await
self
.
engine_core
.
abort_requests_async
(
request_ids
)
self
.
output_processor
.
abort_requests
(
request_ids
)
if
self
.
log_requests
:
if
self
.
log_requests
:
logger
.
info
(
"Aborted request %s."
,
request_id
)
logger
.
info
(
"Aborted request %s."
,
request_id
)
...
...
vllm/v1/engine/llm_engine.py
View file @
f5d3acd4
...
@@ -137,8 +137,8 @@ class LLMEngine:
...
@@ -137,8 +137,8 @@ class LLMEngine:
def
abort_request
(
self
,
request_ids
:
list
[
str
])
->
None
:
def
abort_request
(
self
,
request_ids
:
list
[
str
])
->
None
:
"""Remove request_ids from EngineCore and Detokenizer."""
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids
=
self
.
output_processor
.
abort_requests
(
request_ids
)
self
.
engine_core
.
abort_requests
(
request_ids
)
self
.
engine_core
.
abort_requests
(
request_ids
)
self
.
output_processor
.
abort_requests
(
request_ids
)
def
add_request
(
def
add_request
(
self
,
self
,
...
...
vllm/v1/engine/output_processor.py
View file @
f5d3acd4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
asyncio
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
...
@@ -102,8 +103,7 @@ class RequestState:
...
@@ -102,8 +103,7 @@ class RequestState:
)
->
Optional
[
RequestOutput
]:
)
->
Optional
[
RequestOutput
]:
finished
=
finish_reason
is
not
None
finished
=
finish_reason
is
not
None
output_kind
=
self
.
output_kind
final_only
=
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
final_only
=
output_kind
==
RequestOutputKind
.
FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
# does not stream partial prefills.
...
@@ -111,24 +111,24 @@ class RequestState:
...
@@ -111,24 +111,24 @@ class RequestState:
# Only the final output is required in FINAL_ONLY mode.
# Only the final output is required in FINAL_ONLY mode.
return
None
return
None
def
new_request_output
(
request_id
:
str
)
->
RequestOutput
:
return
self
.
_new_request_output
(
request_id
,
finished
)
completion_output
=
self
.
_new_completion_output
(
completion_output
=
self
.
_new_completion_output
(
new_token_ids
,
finish_reason
,
stop_reason
)
new_token_ids
,
finish_reason
,
stop_reason
)
if
self
.
parent_req
is
not
None
:
request_id
=
self
.
request_id
return
self
.
parent_req
.
make_request_output
(
final_only
,
if
self
.
parent_req
is
None
:
completion_output
,
outputs
=
[
completion_output
]
new_request_output
)
else
:
request_id
,
outputs
,
finished
=
self
.
parent_req
.
get_outputs
(
request_id
,
completion_output
)
if
not
outputs
:
return
None
request_output
=
new_request_output
(
self
.
request_id
)
return
self
.
_new_request_output
(
request_id
,
outputs
,
finished
)
request_output
.
outputs
.
append
(
completion_output
)
return
request_output
def
_new_request_output
(
def
_new_request_output
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
outputs
:
list
[
CompletionOutput
],
finished
:
bool
,
finished
:
bool
,
)
->
RequestOutput
:
)
->
RequestOutput
:
...
@@ -143,7 +143,7 @@ class RequestState:
...
@@ -143,7 +143,7 @@ class RequestState:
prompt
=
self
.
prompt
,
prompt
=
self
.
prompt
,
prompt_token_ids
=
self
.
prompt_token_ids
,
prompt_token_ids
=
self
.
prompt_token_ids
,
prompt_logprobs
=
prompt_logprobs
,
prompt_logprobs
=
prompt_logprobs
,
outputs
=
[]
,
outputs
=
outputs
,
finished
=
finished
,
finished
=
finished
,
)
)
...
@@ -188,6 +188,7 @@ class OutputProcessor:
...
@@ -188,6 +188,7 @@ class OutputProcessor:
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
request_states
:
dict
[
str
,
RequestState
]
=
{}
self
.
request_states
:
dict
[
str
,
RequestState
]
=
{}
self
.
parent_requests
:
dict
[
str
,
ParentRequest
]
=
{}
self
.
lora_states
=
LoRARequestStates
()
self
.
lora_states
=
LoRARequestStates
()
def
get_num_unfinished_requests
(
self
):
def
get_num_unfinished_requests
(
self
):
...
@@ -198,14 +199,20 @@ class OutputProcessor:
...
@@ -198,14 +199,20 @@ class OutputProcessor:
def
abort_requests
(
def
abort_requests
(
self
,
self
,
request_ids
:
list
[
str
],
request_ids
:
Iterable
[
str
],
)
->
None
:
)
->
list
[
str
]:
request_ids_to_abort
=
[]
for
request_id
in
request_ids
:
for
request_id
in
request_ids
:
req_state
=
self
.
request_states
.
pop
(
request_id
,
None
)
req_state
=
self
.
request_states
.
pop
(
request_id
,
None
)
if
req_state
is
not
None
:
if
req_state
is
not
None
:
self
.
lora_states
.
abort_request
(
req_state
)
self
.
lora_states
.
abort_request
(
req_state
)
if
req_state
.
parent_req
is
not
None
:
request_ids_to_abort
.
append
(
request_id
)
req_state
.
parent_req
.
finish_child_request
(
request_id
)
else
:
parent
=
self
.
parent_requests
.
pop
(
request_id
,
None
)
if
parent
and
parent
.
child_requests
:
self
.
abort_requests
(
parent
.
child_requests
)
request_ids_to_abort
.
extend
(
parent
.
child_requests
)
return
request_ids_to_abort
def
add_request
(
def
add_request
(
self
,
self
,
...
@@ -227,6 +234,8 @@ class OutputProcessor:
...
@@ -227,6 +234,8 @@ class OutputProcessor:
log_stats
=
self
.
log_stats
)
log_stats
=
self
.
log_stats
)
self
.
request_states
[
request_id
]
=
req_state
self
.
request_states
[
request_id
]
=
req_state
self
.
lora_states
.
add_request
(
req_state
)
self
.
lora_states
.
add_request
(
req_state
)
if
parent_req
:
self
.
parent_requests
[
parent_req
.
request_id
]
=
parent_req
def
process_outputs
(
def
process_outputs
(
self
,
self
,
...
@@ -314,12 +323,14 @@ class OutputProcessor:
...
@@ -314,12 +323,14 @@ class OutputProcessor:
# Free completed requests.
# Free completed requests.
if
finish_reason
is
not
None
:
if
finish_reason
is
not
None
:
self
.
request_states
.
pop
(
req_id
)
self
.
request_states
.
pop
(
req_id
)
# Remove parent request if applicable.
parent_req
=
req_state
.
parent_req
if
parent_req
and
not
parent_req
.
child_requests
:
self
.
parent_requests
.
pop
(
parent_req
.
request_id
,
None
)
if
not
engine_core_output
.
finished
:
if
not
engine_core_output
.
finished
:
# If req not finished in EngineCore, but Detokenizer
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
# detected stop string, abort needed in EngineCore.
reqs_to_abort
.
append
(
req_id
)
reqs_to_abort
.
append
(
req_id
)
if
req_state
.
parent_req
is
not
None
:
req_state
.
parent_req
.
finish_child_request
(
req_id
)
# Track per-request stats
# Track per-request stats
self
.
_update_stats_from_finished
(
req_state
,
finish_reason
,
self
.
_update_stats_from_finished
(
req_state
,
finish_reason
,
...
...
vllm/v1/engine/parallel_sampling.py
View file @
f5d3acd4
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
copy
import
copy
from
copy
import
copy
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Optional
,
Union
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.metrics.stats
import
IterationStats
...
@@ -23,7 +23,7 @@ class ParentRequest:
...
@@ -23,7 +23,7 @@ class ParentRequest:
child_requests
:
set
[
str
]
child_requests
:
set
[
str
]
# To aggregate child completions when not streaming
# To aggregate child completions when not streaming
output_aggregator
:
Optional
[
Request
Output
]
output_aggregator
:
list
[
Completion
Output
]
# To find the max number of generated tokens across all children
# To find the max number of generated tokens across all children
max_num_generation_tokens
:
int
max_num_generation_tokens
:
int
...
@@ -37,7 +37,9 @@ class ParentRequest:
...
@@ -37,7 +37,9 @@ class ParentRequest:
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
self
.
child_requests
=
set
()
self
.
child_requests
=
set
()
self
.
output_aggregator
=
None
self
.
output_aggregator
=
[
None
]
*
sampling_params
.
n
if
(
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
)
else
[]
self
.
max_num_generation_tokens
=
0
self
.
max_num_generation_tokens
=
0
self
.
cached_child_sampling_params
=
None
self
.
cached_child_sampling_params
=
None
...
@@ -93,43 +95,30 @@ class ParentRequest:
...
@@ -93,43 +95,30 @@ class ParentRequest:
"""
"""
child_req_id
=
f
"
{
index
}
_
{
self
.
request_id
}
"
child_req_id
=
f
"
{
index
}
_
{
self
.
request_id
}
"
self
.
child_requests
.
add
(
child_req_id
)
self
.
child_requests
.
add
(
child_req_id
)
return
(
child_req_id
,
self
.
_get_child_sampling_params
(
index
))
return
child_req_id
,
self
.
_get_child_sampling_params
(
index
)
def
finish_child_request
(
self
,
req_id
:
str
):
self
.
child_requests
.
remove
(
req_id
)
@
property
@
property
def
n
(
self
)
->
int
:
def
n
(
self
)
->
int
:
return
self
.
sampling_params
.
n
return
self
.
sampling_params
.
n
def
make_reques
t_output
(
def
ge
t_output
s
(
self
,
self
,
final_only
:
bool
,
child_request_id
:
str
,
completion_output
:
CompletionOutput
,
completion_output
:
CompletionOutput
,
new_request_output
:
Callable
[[
str
],
RequestOutput
],
)
->
tuple
[
str
,
list
[
CompletionOutput
],
bool
]:
)
->
Optional
[
RequestOutput
]:
if
completion_output
.
finished
():
# Use an existing RequestOutput if we're aggregating
self
.
child_requests
.
remove
(
child_request_id
)
request_output
=
self
.
output_aggregator
# Make new RequestOutput otherwise
if
request_output
is
None
:
request_output
=
new_request_output
(
self
.
request_id
)
# Add a new completion
request_output
.
outputs
.
append
(
completion_output
)
# If not streaming, aggregate until all child requests complete
if
self
.
sampling_params
.
output_kind
!=
RequestOutputKind
.
FINAL_ONLY
:
if
final_only
and
len
(
request_output
.
outputs
)
!=
self
.
n
:
# If streaming, just return the current output.
self
.
output
_aggregator
=
request
_output
output
s
=
[
completion
_output
]
return
None
else
:
# If not streaming, aggregate the n final outputs.
# We're done aggregating
self
.
output_aggregator
[
completion_output
.
index
]
=
completion_output
self
.
output_aggregator
=
None
outputs
=
[]
if
self
.
child_requests
else
self
.
output_aggregator
# Parent completion output list must be sorted by index
finished
=
not
self
.
child_requests
request_output
.
outputs
=
sorted
(
request_output
.
outputs
,
return
self
.
request_id
,
outputs
,
finished
key
=
lambda
x
:
x
.
index
)
return
request_output
def
observe_num_generation_tokens
(
self
,
num_generation_tokens
:
int
):
def
observe_num_generation_tokens
(
self
,
num_generation_tokens
:
int
):
self
.
max_num_generation_tokens
=
max
(
num_generation_tokens
,
self
.
max_num_generation_tokens
=
max
(
num_generation_tokens
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment