Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eeec9e33
Unverified
Commit
eeec9e33
authored
Dec 13, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 13, 2024
Browse files
[Frontend] Separate pooling APIs in offline inference (#11129)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f93bf2b1
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
22 deletions
+18
-22
vllm/sequence.py
vllm/sequence.py
+18
-22
No files found.
vllm/sequence.py
View file @
eeec9e33
...
...
@@ -617,10 +617,9 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
embeddings: The embeddings vectors of the prompt of the sequence group
for a pooling model.
pooling_params: The pooling parameters used to generate the pooling
pooling_params: The parameters used to generate the pooler
for a pooling model.
pooled_data: The extracted hidden states from a pooling model.
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
...
...
@@ -635,8 +634,8 @@ class SequenceGroup:
arrival_time
:
float
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
embeddings
:
Optional
[
List
[
float
]]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
pooled_data
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
...
@@ -658,8 +657,8 @@ class SequenceGroup:
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
pooled_data
=
pooled_data
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
...
...
@@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
__metaclass__
=
SequenceGroupOutput
"""The model output associated with a completion sequence group."""
__metaclass__
=
SequenceGroupOutput
samples
:
List
[
SequenceOutput
]
# Prompt logprob for each prompt query token.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
...
...
@@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput(
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
class
Embedd
ingSequenceGroupOutput
(
class
Pool
ingSequenceGroupOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
):
"""The model output associated with a
n embedd
ing sequence group."""
"""The model output associated with a
pool
ing sequence group."""
__metaclass__
=
SequenceGroupOutput
embeddings
:
List
[
int
]
# Annotated as Any to be compatible with msgspec
# The actual type is in SequenceGroup.pooled_data
data
:
Any
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingSequenceGroupOutput("
f
"embeddings_shape=
{
len
(
self
.
embeddings
)
}
)"
)
return
f
"PoolingSequenceGroupOutput(data=
{
self
.
data
}
"
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
Embedd
ingSequenceGroupOutput
):
if
not
isinstance
(
other
,
Pool
ingSequenceGroupOutput
):
raise
NotImplementedError
()
return
self
.
embeddings
==
other
.
embeddings
return
self
.
data
==
other
.
data
# cannot use msgspec.Struct here because Dynamo does not support it
...
...
@@ -1085,7 +1085,7 @@ class IntermediateTensors:
elif
isinstance
(
key
,
slice
):
return
self
.
__class__
({
k
:
v
[
key
]
for
k
,
v
in
self
.
tensors
.
items
()})
def
__setitem__
(
self
,
key
:
str
,
value
):
def
__setitem__
(
self
,
key
:
str
,
value
:
torch
.
Tensor
):
self
.
tensors
[
key
]
=
value
def
__len__
(
self
):
...
...
@@ -1103,16 +1103,12 @@ class PoolerOutput(
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The output from a pooling operation in the pooling model."""
outputs
:
List
[
EmbeddingSequenceGroupOutput
]
# lazy import to avoid circular import
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
outputs
:
List
[
PoolingSequenceGroupOutput
]
def
__getitem__
(
self
,
idx
:
int
)
->
Embedd
ingSequenceGroupOutput
:
def
__getitem__
(
self
,
idx
:
int
)
->
Pool
ingSequenceGroupOutput
:
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
def
__setitem__
(
self
,
idx
:
int
,
value
:
PoolingSequenceGroupOutput
):
self
.
outputs
[
idx
]
=
value
def
__len__
(
self
):
...
...
@@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
arrival_time
=
seq_group
.
arrival_time
,
sampling_params
=
original_params
,
lora_request
=
seq_group
.
lora_request
,
embeddings
=
seq_group
.
embeddings
,
pooling_params
=
seq_group
.
pooling_params
,
pooled_data
=
seq_group
.
pooled_data
,
encoder_seq
=
seq_group
.
encoder_seq
,
trace_headers
=
seq_group
.
trace_headers
,
prompt_adapter_request
=
seq_group
.
prompt_adapter_request
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment