Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5ae5ed1e
Unverified
Commit
5ae5ed1e
authored
May 29, 2024
by
Cyrus Leung
Committed by
GitHub
May 28, 2024
Browse files
[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
290f4ada
Changes
43
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
83 additions
and
40 deletions
+83
-40
vllm/outputs.py
vllm/outputs.py
+15
-27
vllm/sequence.py
vllm/sequence.py
+26
-12
vllm/utils.py
vllm/utils.py
+42
-1
No files found.
vllm/outputs.py
View file @
5ae5ed1e
import
time
import
time
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
...
@@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup
,
SequenceStatus
)
SequenceGroup
,
SequenceStatus
)
@
dataclass
class
CompletionOutput
:
class
CompletionOutput
:
"""The output data of one completion output of a request.
"""The output data of one completion output of a request.
...
@@ -24,25 +26,14 @@ class CompletionOutput:
...
@@ -24,25 +26,14 @@ class CompletionOutput:
lora_request: The LoRA request that was used to generate the output.
lora_request: The LoRA request that was used to generate the output.
"""
"""
def
__init__
(
index
:
int
self
,
text
:
str
index
:
int
,
token_ids
:
List
[
int
]
text
:
str
,
cumulative_logprob
:
float
token_ids
:
List
[
int
],
logprobs
:
Optional
[
SampleLogprobs
]
cumulative_logprob
:
float
,
finish_reason
:
Optional
[
str
]
=
None
logprobs
:
Optional
[
SampleLogprobs
],
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
finish_reason
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
None
:
self
.
index
=
index
self
.
text
=
text
self
.
token_ids
=
token_ids
self
.
cumulative_logprob
=
cumulative_logprob
self
.
logprobs
=
logprobs
self
.
finish_reason
=
finish_reason
self
.
stop_reason
=
stop_reason
self
.
lora_request
=
lora_request
def
finished
(
self
)
->
bool
:
def
finished
(
self
)
->
bool
:
return
self
.
finish_reason
is
not
None
return
self
.
finish_reason
is
not
None
...
@@ -57,6 +48,7 @@ class CompletionOutput:
...
@@ -57,6 +48,7 @@ class CompletionOutput:
f
"stop_reason=
{
self
.
stop_reason
}
)"
)
f
"stop_reason=
{
self
.
stop_reason
}
)"
)
@
dataclass
class
EmbeddingOutput
:
class
EmbeddingOutput
:
"""The output data of one completion output of a request.
"""The output data of one completion output of a request.
...
@@ -65,15 +57,11 @@ class EmbeddingOutput:
...
@@ -65,15 +57,11 @@ class EmbeddingOutput:
length of vector depends on the model as listed in the embedding guide.
length of vector depends on the model as listed in the embedding guide.
"""
"""
def
__init__
(
embedding
:
List
[
float
]
self
,
embedding
:
List
[
float
],
)
->
None
:
self
.
embedding
=
embedding
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingOutput("
return
(
f
"EmbeddingOutput("
f
"embedding=
{
len
(
self
.
embedding
)
}
"
)
f
"embedding=
{
len
(
self
.
embedding
)
}
)
"
)
class
RequestOutput
:
class
RequestOutput
:
...
@@ -93,7 +81,7 @@ class RequestOutput:
...
@@ -93,7 +81,7 @@ class RequestOutput:
def
__init__
(
def
__init__
(
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
str
,
prompt
:
Optional
[
str
]
,
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
outputs
:
List
[
CompletionOutput
],
...
@@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
...
@@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
finished (bool): A flag indicating whether the embedding is completed.
finished (bool): A flag indicating whether the embedding is completed.
"""
"""
def
__init__
(
self
,
request_id
:
str
,
outputs
:
'
EmbeddingOutput
'
,
def
__init__
(
self
,
request_id
:
str
,
outputs
:
"
EmbeddingOutput
"
,
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
prompt_token_ids
:
List
[
int
],
finished
:
bool
):
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
prompt_token_ids
=
prompt_token_ids
self
.
prompt_token_ids
=
prompt_token_ids
...
...
vllm/sequence.py
View file @
5ae5ed1e
...
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
...
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
vllm.block
import
LogicalTokenBlock
from
vllm.block
import
LogicalTokenBlock
from
vllm.inputs
import
LLMInputs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
...
@@ -210,8 +211,7 @@ class Sequence:
...
@@ -210,8 +211,7 @@ class Sequence:
Args:
Args:
seq_id: The ID of the sequence.
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
inputs: The inputs of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
block size used by the block manager and cache engine.
lora_request: LoRA request.
lora_request: LoRA request.
...
@@ -220,25 +220,24 @@ class Sequence:
...
@@ -220,25 +220,24 @@ class Sequence:
def
__init__
(
def
__init__
(
self
,
self
,
seq_id
:
int
,
seq_id
:
int
,
prompt
:
str
,
inputs
:
LLMInputs
,
prompt_token_ids
:
List
[
int
],
block_size
:
int
,
block_size
:
int
,
eos_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
None
:
)
->
None
:
self
.
seq_id
=
seq_id
self
.
seq_id
=
seq_id
self
.
prompt
=
prompt
self
.
inputs
=
inputs
self
.
block_size
=
block_size
self
.
block_size
=
block_size
self
.
eos_token_id
=
eos_token_id
self
.
eos_token_id
=
eos_token_id
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
data
:
SequenceData
=
SequenceData
(
prompt_token_ids
)
self
.
data
=
SequenceData
(
self
.
prompt_token_ids
)
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_text
=
""
self
.
output_text
=
""
self
.
logical_token_blocks
:
List
[
LogicalTokenBlock
]
=
[]
self
.
logical_token_blocks
:
List
[
LogicalTokenBlock
]
=
[]
# Initialize the logical token blocks with the prompt token ids.
# Initialize the logical token blocks with the prompt token ids.
self
.
_append_tokens_to_blocks
(
prompt_token_ids
)
self
.
_append_tokens_to_blocks
(
self
.
prompt_token_ids
)
self
.
status
=
SequenceStatus
.
WAITING
self
.
status
=
SequenceStatus
.
WAITING
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
...
@@ -248,6 +247,18 @@ class Sequence:
...
@@ -248,6 +247,18 @@ class Sequence:
# Input + output tokens
# Input + output tokens
self
.
tokens
:
Optional
[
List
[
str
]]
=
None
self
.
tokens
:
Optional
[
List
[
str
]]
=
None
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
return
self
.
inputs
[
"prompt"
]
@
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
return
self
.
inputs
[
"prompt_token_ids"
]
@
property
def
multi_modal_data
(
self
)
->
Optional
[
"MultiModalData"
]:
return
self
.
inputs
[
"multi_modal_data"
]
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
return
self
.
lora_request
.
lora_int_id
if
self
.
lora_request
else
0
...
@@ -415,7 +426,6 @@ class SequenceGroup:
...
@@ -415,7 +426,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group
embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model.
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
pooling_params: The pooling parameters used to generate the pooling
...
@@ -429,7 +439,6 @@ class SequenceGroup:
...
@@ -429,7 +439,6 @@ class SequenceGroup:
arrival_time
:
float
,
arrival_time
:
float
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
multi_modal_data
:
Optional
[
MultiModalData
]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -444,12 +453,11 @@ class SequenceGroup:
...
@@ -444,12 +453,11 @@ class SequenceGroup:
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
state
=
SequenceGroupState
()
self
.
multi_modal_data
=
multi_modal_data
self
.
embeddings
=
embeddings
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
pooling_params
=
pooling_params
@
property
@
property
def
prompt
(
self
)
->
str
:
def
prompt
(
self
)
->
Optional
[
str
]
:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
# We use the prompt of an arbitrary sequence.
return
next
(
iter
(
self
.
seqs_dict
.
values
())).
prompt
return
next
(
iter
(
self
.
seqs_dict
.
values
())).
prompt
...
@@ -458,7 +466,13 @@ class SequenceGroup:
...
@@ -458,7 +466,13 @@ class SequenceGroup:
def
prompt_token_ids
(
self
)
->
List
[
int
]:
def
prompt_token_ids
(
self
)
->
List
[
int
]:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
# We use the prompt of an arbitrary sequence.
return
next
(
iter
(
self
.
seqs_dict
.
values
())).
data
.
prompt_token_ids
return
next
(
iter
(
self
.
seqs_dict
.
values
())).
prompt_token_ids
@
property
def
multi_modal_data
(
self
)
->
Optional
[
MultiModalData
]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return
next
(
iter
(
self
.
seqs_dict
.
values
())).
multi_modal_data
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
...
...
vllm/utils.py
View file @
5ae5ed1e
...
@@ -11,7 +11,7 @@ import threading
...
@@ -11,7 +11,7 @@ import threading
import
uuid
import
uuid
import
warnings
import
warnings
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
AsyncIterator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
from
typing
import
(
Any
,
AsyncIterator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
Hashable
,
List
,
Optional
,
OrderedDict
,
Tuple
,
TypeVar
,
...
@@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
...
@@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
filename
)
filename
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
log_path
),
exist_ok
=
True
)
enable_trace_function_call
(
log_path
)
enable_trace_function_call
(
log_path
)
def
identity
(
value
:
T
)
->
T
:
return
value
F
=
TypeVar
(
'F'
,
bound
=
Callable
[...,
Any
])
def
deprecate_kwargs
(
*
kws
:
str
,
is_deprecated
:
Union
[
bool
,
Callable
[[],
bool
]]
=
True
,
additional_message
:
Optional
[
str
]
=
None
)
->
Callable
[[
F
],
F
]:
deprecated_kws
=
set
(
kws
)
if
not
callable
(
is_deprecated
):
is_deprecated
=
partial
(
identity
,
is_deprecated
)
def
wrapper
(
fn
:
F
)
->
F
:
@
wraps
(
fn
)
def
inner
(
*
args
,
**
kwargs
):
if
is_deprecated
():
deprecated_kwargs
=
kwargs
.
keys
()
&
deprecated_kws
if
deprecated_kwargs
:
msg
=
(
f
"The keyword arguments
{
deprecated_kwargs
}
are "
"deprecated and will be removed in a future update."
)
if
additional_message
is
not
None
:
msg
+=
f
"
{
additional_message
}
"
warnings
.
warn
(
DeprecationWarning
(
msg
),
stacklevel
=
3
,
# The inner function takes up one level
)
return
fn
(
*
args
,
**
kwargs
)
return
inner
# type: ignore
return
wrapper
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment