Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
76a5e132
Unverified
Commit
76a5e132
authored
Oct 21, 2024
by
youkaichao
Committed by
GitHub
Oct 22, 2024
Browse files
[core] move parallel sampling out from vllm core (#9302)
parent
ef7faad1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
222 additions
and
29 deletions
+222
-29
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+34
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+42
-10
vllm/outputs.py
vllm/outputs.py
+26
-17
vllm/sequence.py
vllm/sequence.py
+120
-2
No files found.
tests/entrypoints/openai/test_completion.py
View file @
76a5e132
...
...
@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert
""
.
join
(
chunks
)
==
single_output
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
,
"zephyr-lora"
,
"zephyr-pa"
],
)
async
def
test_parallel_streaming
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt
=
"What is an LLM?"
n
=
3
max_tokens
=
5
stream
=
await
client
.
completions
.
create
(
model
=
model_name
,
prompt
=
prompt
,
max_tokens
=
max_tokens
,
n
=
n
,
stream
=
True
)
chunks
:
List
[
List
[
str
]]
=
[[]
for
i
in
range
(
n
)]
finish_reason_count
=
0
async
for
chunk
in
stream
:
index
=
chunk
.
choices
[
0
].
index
text
=
chunk
.
choices
[
0
].
text
chunks
[
index
].
append
(
text
)
if
chunk
.
choices
[
0
].
finish_reason
is
not
None
:
finish_reason_count
+=
1
assert
finish_reason_count
==
n
for
chunk
in
chunks
:
assert
len
(
chunk
)
==
max_tokens
print
(
""
.
join
(
chunk
))
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
...
...
vllm/engine/llm_engine.py
View file @
76a5e132
...
...
@@ -44,8 +44,10 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceStatus
)
ParallelSampleSequenceGroup
,
Sequence
,
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
...
...
@@ -474,6 +476,8 @@ class LLMEngine:
),
))
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -642,7 +646,10 @@ class LLMEngine:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
SequenceGroup
:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
...
...
@@ -696,6 +703,8 @@ class LLMEngine:
min_cost_scheduler
=
self
.
scheduler
[
costs
.
index
(
min
(
costs
))]
min_cost_scheduler
.
add_seq_group
(
seq_group
)
return
seq_group
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
...
...
@@ -711,7 +720,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
Optional
[
SequenceGroup
]
:
...
@
overload
...
...
@@ -725,7 +734,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
)
->
Optional
[
SequenceGroup
]
:
...
@
deprecate_kwargs
(
...
...
@@ -744,7 +753,7 @@ class LLMEngine:
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
)
->
Optional
[
SequenceGroup
]
:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
...
...
@@ -788,6 +797,22 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if
isinstance
(
params
,
SamplingParams
)
and
params
.
n
>
1
:
ParallelSampleSequenceGroup
.
add_request
(
request_id
,
self
,
params
,
prompt
=
prompt
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
inputs
=
inputs
,
)
return
None
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
...
...
@@ -818,7 +843,7 @@ class LLMEngine:
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
"mm_processor_kwargs"
)
self
.
_add_processed_request
(
return
self
.
_add_processed_request
(
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
params
=
params
,
...
...
@@ -1135,7 +1160,9 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
@@ -1175,7 +1202,9 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
@@ -1194,7 +1223,10 @@ class LLMEngine:
continue
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
,
)
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
...
...
vllm/outputs.py
View file @
76a5e132
import
time
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
SequenceGroup
,
SequenceStatus
)
SequenceGroup
,
SequenceGroupBase
,
SequenceStatus
)
@
dataclass
...
...
@@ -114,14 +114,28 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
,
use_cache
:
bool
)
->
Optional
[
"RequestOutput"
]:
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
,
use_cache
:
bool
,
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
)
->
Optional
[
"RequestOutput"
]:
finished
=
seq_group
.
is_finished
()
if
seq_group
.
request_id
in
seq_id_to_seq_group
:
group
:
SequenceGroupBase
=
seq_id_to_seq_group
[
seq_group
.
request_id
]
if
finished
:
group
.
finish_seq
(
seq_group
)
assembled_seq_group
=
group
.
maybe_assemble_group
(
seq_group
)
if
assembled_seq_group
is
None
:
return
None
return
cls
.
from_seq_group
(
assembled_seq_group
,
use_cache
,
seq_id_to_seq_group
)
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
is
None
:
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
finished
=
seq_group
.
is_finished
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
not
finished
):
return
None
...
...
@@ -136,15 +150,7 @@ class RequestOutput:
outputs
=
[],
finished
=
False
)
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
else
:
# Get the top-n sequences.
n
=
sampling_params
.
_real_n
or
sampling_params
.
n
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
top_n_seqs
=
sorted_seqs
[:
n
]
top_n_seqs
=
seq_group
.
get_seqs
()
# Create the outputs.
# NOTE: We need omit logprobs here explicitly because the sequence
...
...
@@ -208,7 +214,7 @@ class RequestOutput:
else
:
output
=
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
[
output_token_ids
]
top_n_
seqs
.
index
(
seq
),
output_text
,
[
output_token_ids
]
if
isinstance
(
output_token_ids
,
int
)
else
output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_logprobs
,
...
...
@@ -309,10 +315,13 @@ class EmbeddingRequestOutput:
class
RequestOutputFactory
:
@
staticmethod
def
create
(
seq_group
:
SequenceGroup
,
use_cache
:
bool
=
False
):
def
create
(
seq_group
:
SequenceGroup
,
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
],
use_cache
:
bool
=
False
):
# Determine the type based on a condition, for example:
if
hasattr
(
seq_group
,
'embeddings'
)
and
seq_group
.
embeddings
is
not
None
:
return
EmbeddingRequestOutput
.
from_seq_group
(
seq_group
)
else
:
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
)
return
RequestOutput
.
from_seq_group
(
seq_group
,
use_cache
,
seq_id_to_seq_group
)
vllm/sequence.py
View file @
76a5e132
...
...
@@ -4,7 +4,7 @@ import enum
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
,
field
from
functools
import
cached_property
,
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
...
...
@@ -17,7 +17,7 @@ from vllm.inputs.parse import is_encoder_decoder_inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
if
TYPE_CHECKING
:
...
...
@@ -1401,3 +1401,121 @@ class ExecuteModelRequest(
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
)
@
dataclass
class
SequenceGroupBase
:
group_id
:
str
# the original request id before splitting
assembled_seq_group
:
Optional
[
SequenceGroup
]
=
None
# seq id to a unique index inside this group
seq_id_to_index
:
Dict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
# seq ids to be finished
to_be_finished
:
Dict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
# seq id to finished sequences
finished_reqs
:
Dict
[
str
,
SequenceGroup
]
=
field
(
default_factory
=
dict
)
streaming
:
bool
=
False
output_produced
:
bool
=
False
@
staticmethod
def
add_request
(
request_id
:
str
,
engine
,
params
,
*
args
,
**
kwargs
):
"""When we are ready to add a request with request_id and params
into the engine, we can split the request into multiple requests.
"""
raise
NotImplementedError
def
finish_seq
(
self
,
seq
:
SequenceGroup
):
"""The sequence `seq` finishes, we should record the information.
"""
del
self
.
to_be_finished
[
seq
.
request_id
]
self
.
finished_reqs
[
seq
.
request_id
]
=
seq
def
maybe_assemble_group
(
self
,
seq_group
:
SequenceGroup
)
->
Optional
[
SequenceGroup
]:
"""Assemble the sequence group, for producing the final
output, or adding request in the engine again.
"""
raise
NotImplementedError
class
ParallelSampleSequenceGroup
(
SequenceGroupBase
):
@
staticmethod
def
add_request
(
request_id
:
str
,
engine
,
params
,
**
kwargs
):
original_params
=
params
params
=
copy
.
deepcopy
(
original_params
)
params
.
n
=
1
group
=
ParallelSampleSequenceGroup
(
request_id
)
seqs
=
[]
for
i
in
range
(
original_params
.
n
):
request_id_i
=
f
"
{
request_id
}
_parallel_sample_
{
i
}
"
group
.
seq_id_to_index
[
request_id_i
]
=
i
seq_group
=
engine
.
add_request
(
request_id_i
,
params
=
params
,
**
kwargs
,
)
# type: ignore
assert
seq_group
is
not
None
engine
.
seq_id_to_seq_group
[
request_id_i
]
=
group
group
.
to_be_finished
[
request_id_i
]
=
seq_group
seqs
.
append
(
seq_group
.
seqs
[
0
])
# for parallel sampling, the `assembled_seq_group` is always
# available, since we have all the sequences ready, and they
# will not change.
group
.
assembled_seq_group
=
SequenceGroup
(
request_id
=
request_id
,
seqs
=
seqs
,
arrival_time
=
seq_group
.
arrival_time
,
sampling_params
=
original_params
,
lora_request
=
seq_group
.
lora_request
,
embeddings
=
seq_group
.
embeddings
,
pooling_params
=
seq_group
.
pooling_params
,
encoder_seq
=
seq_group
.
encoder_seq
,
trace_headers
=
seq_group
.
trace_headers
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
priority
=
seq_group
.
priority
,
)
group
.
streaming
=
params
.
output_kind
==
RequestOutputKind
.
DELTA
group
.
output_produced
=
False
def
maybe_assemble_group
(
self
,
seq_group
:
SequenceGroup
)
->
Optional
[
SequenceGroup
]:
# in the streaming mode, we will return the assembled sequence
# for the first sequence, and then return None for the rest of
# sequences
if
self
.
streaming
:
if
self
.
seq_id_to_index
[
seq_group
.
request_id
]
==
0
:
return
self
.
assembled_seq_group
return
None
# in the non-streaming mode, we will return the assembled sequence
# once after all sequences finish, and then return None for the
# rest of the time
if
len
(
self
.
to_be_finished
)
>
0
:
return
None
assert
self
.
assembled_seq_group
is
not
None
params
=
self
.
assembled_seq_group
.
sampling_params
assert
isinstance
(
params
,
SamplingParams
)
if
not
self
.
output_produced
:
self
.
output_produced
=
True
if
params
.
_real_n
is
not
None
:
# Get the top-n sequences.
n
=
params
.
_real_n
or
params
.
n
seqs
=
self
.
assembled_seq_group
.
seqs
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
top_n_seqs
=
sorted_seqs
[:
n
]
self
.
assembled_seq_group
.
seqs
=
top_n_seqs
return
self
.
assembled_seq_group
if
self
.
output_produced
:
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment