Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83449a5f
Unverified
Commit
83449a5f
authored
Feb 03, 2026
by
Cyrus Leung
Committed by
GitHub
Feb 03, 2026
Browse files
[Refactor] Clean up pooling serial utils (#33665)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
dad2d6a5
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
407 additions
and
322 deletions
+407
-322
examples/pooling/embed/embedding_requests_base64_online.py
examples/pooling/embed/embedding_requests_base64_online.py
+2
-6
examples/pooling/embed/embedding_requests_bytes_online.py
examples/pooling/embed/embedding_requests_bytes_online.py
+4
-5
tests/entrypoints/pooling/embed/test_online.py
tests/entrypoints/pooling/embed/test_online.py
+7
-9
tests/entrypoints/pooling/pooling/test_online.py
tests/entrypoints/pooling/pooling/test_online.py
+6
-8
tests/utils_/test_serial_utils.py
tests/utils_/test_serial_utils.py
+5
-3
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+111
-67
vllm/entrypoints/pooling/pooling/serving.py
vllm/entrypoints/pooling/pooling/serving.py
+107
-67
vllm/entrypoints/pooling/utils.py
vllm/entrypoints/pooling/utils.py
+124
-0
vllm/utils/serial_utils.py
vllm/utils/serial_utils.py
+41
-157
No files found.
examples/pooling/embed/embedding_requests_base64_online.py
View file @
83449a5f
...
...
@@ -12,11 +12,7 @@ import base64
import
requests
import
torch
from
vllm.utils.serial_utils
import
(
EMBED_DTYPE_TO_TORCH_DTYPE
,
ENDIANNESS
,
binary2tensor
,
)
from
vllm.utils.serial_utils
import
EMBED_DTYPES
,
ENDIANNESS
,
binary2tensor
def
post_http_request
(
prompt
:
dict
,
api_url
:
str
)
->
requests
.
Response
:
...
...
@@ -45,7 +41,7 @@ def main(args):
]
*
2
# The OpenAI client does not support the embed_dtype and endianness parameters.
for
embed_dtype
in
EMBED_DTYPE
_TO_TORCH_DTYPE
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
prompt
=
{
"model"
:
model
,
...
...
examples/pooling/embed/embedding_requests_bytes_online.py
View file @
83449a5f
...
...
@@ -12,13 +12,12 @@ import json
import
requests
import
torch
from
vllm.utils.serial_utils
import
(
EMBED_DTYPE_TO_TORCH_DTYPE
,
ENDIANNESS
,
from
vllm.entrypoints.pooling.utils
import
(
MetadataItem
,
build_metadata_items
,
decode_pooling_output
,
)
from
vllm.utils.serial_utils
import
EMBED_DTYPES
,
ENDIANNESS
def
post_http_request
(
prompt
:
dict
,
api_url
:
str
)
->
requests
.
Response
:
...
...
@@ -51,7 +50,7 @@ def main(args):
# The OpenAI client does not support the bytes encoding_format.
# The OpenAI client does not support the embed_dtype and endianness parameters.
for
embed_dtype
in
EMBED_DTYPE
_TO_TORCH_DTYPE
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
prompt
=
{
"model"
:
model
,
...
...
@@ -74,7 +73,7 @@ def main(args):
# The vllm server always sorts the returned embeddings in the order of input. So
# returning metadata is not necessary. You can set encoding_format to bytes_only
# to let the server not return metadata.
for
embed_dtype
in
EMBED_DTYPE
_TO_TORCH_DTYPE
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
prompt
=
{
"model"
:
model
,
...
...
tests/entrypoints/pooling/embed/test_online.py
View file @
83449a5f
...
...
@@ -17,16 +17,14 @@ from tests.models.utils import check_embeddings_close
from
tests.utils
import
RemoteOpenAIServer
from
vllm.entrypoints.pooling.embed.protocol
import
EmbeddingResponse
from
vllm.entrypoints.pooling.pooling.protocol
import
PoolingResponse
from
vllm.platforms
import
current_platform
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.serial_utils
import
(
EMBED_DTYPE_TO_TORCH_DTYPE
,
ENDIANNESS
,
from
vllm.entrypoints.pooling.utils
import
(
MetadataItem
,
binary2tensor
,
build_metadata_items
,
decode_pooling_output
,
)
from
vllm.platforms
import
current_platform
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.serial_utils
import
EMBED_DTYPES
,
ENDIANNESS
,
binary2tensor
MODEL_NAME
=
"intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE
=
"""{% for message in messages %}{{message['role'] + ': ' + message['content'] + '
\\
n'}}{% endfor %}"""
# noqa: E501
...
...
@@ -535,7 +533,7 @@ async def test_base64_embed_dtype_and_endianness(
)
float_data
=
[
d
.
embedding
for
d
in
responses_float
.
data
]
for
embed_dtype
in
EMBED_DTYPE
_TO_TORCH_DTYPE
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
responses_base64
=
requests
.
post
(
server
.
url_for
(
"/v1/embeddings"
),
...
...
@@ -574,7 +572,7 @@ async def test_bytes_embed_dtype_and_endianness(
)
float_data
=
[
d
.
embedding
for
d
in
responses_float
.
data
]
for
embed_dtype
in
list
(
EMBED_DTYPE
_TO_TORCH_DTYPE
.
keys
())
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
responses_bytes
=
requests
.
post
(
server
.
url_for
(
"/v1/embeddings"
),
...
...
@@ -618,7 +616,7 @@ async def test_bytes_only_embed_dtype_and_endianness(
float_data
=
[
d
.
embedding
for
d
in
responses_float
.
data
]
embedding_size
=
len
(
float_data
[
0
])
for
embed_dtype
in
list
(
EMBED_DTYPE
_TO_TORCH_DTYPE
.
keys
())
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
responses_bytes
=
requests
.
post
(
server
.
url_for
(
"/v1/embeddings"
),
...
...
tests/entrypoints/pooling/pooling/test_online.py
View file @
83449a5f
...
...
@@ -12,15 +12,13 @@ import torch
from
tests.models.utils
import
check_embeddings_close
from
tests.utils
import
RemoteOpenAIServer
from
vllm.entrypoints.pooling.pooling.protocol
import
PoolingResponse
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.serial_utils
import
(
EMBED_DTYPE_TO_TORCH_DTYPE
,
ENDIANNESS
,
from
vllm.entrypoints.pooling.utils
import
(
MetadataItem
,
binary2tensor
,
build_metadata_items
,
decode_pooling_output
,
)
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.serial_utils
import
EMBED_DTYPES
,
ENDIANNESS
,
binary2tensor
MODEL_NAME
=
"internlm/internlm2-1_8b-reward"
DUMMY_CHAT_TEMPLATE
=
"""{% for message in messages %}{{message['role'] + ': ' + message['content'] + '
\\
n'}}{% endfor %}"""
# noqa: E501
...
...
@@ -342,7 +340,7 @@ async def test_base64_embed_dtype_and_endianness(
responses_float
=
PoolingResponse
.
model_validate
(
float_response
.
json
())
float_data
=
[
np
.
array
(
d
.
data
).
squeeze
(
-
1
).
tolist
()
for
d
in
responses_float
.
data
]
for
embed_dtype
in
EMBED_DTYPE
_TO_TORCH_DTYPE
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
responses_base64
=
requests
.
post
(
url
,
...
...
@@ -389,7 +387,7 @@ async def test_bytes_embed_dtype_and_endianness(
responses_float
=
PoolingResponse
.
model_validate
(
float_response
.
json
())
float_data
=
[
np
.
array
(
d
.
data
).
squeeze
(
-
1
).
tolist
()
for
d
in
responses_float
.
data
]
for
embed_dtype
in
list
(
EMBED_DTYPE
_TO_TORCH_DTYPE
.
keys
())
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
responses_bytes
=
requests
.
post
(
url
,
...
...
@@ -438,7 +436,7 @@ async def test_bytes_only_embed_dtype_and_endianness(
float_data
=
[
np
.
array
(
d
.
data
).
squeeze
(
-
1
).
tolist
()
for
d
in
responses_float
.
data
]
n_tokens
=
responses_float
.
usage
.
prompt_tokens
//
len
(
input_texts
)
for
embed_dtype
in
list
(
EMBED_DTYPE
_TO_TORCH_DTYPE
.
keys
())
:
for
embed_dtype
in
EMBED_DTYPE
S
:
for
endianness
in
ENDIANNESS
:
responses_bytes
=
requests
.
post
(
url
,
...
...
tests/utils_/test_serial_utils.py
View file @
83449a5f
...
...
@@ -5,17 +5,19 @@ import torch
from
tests.models.utils
import
check_embeddings_close
from
vllm.utils.serial_utils
import
(
EMBED_DTYPE
_TO_TORCH_DTYPE
,
EMBED_DTYPE
S
,
ENDIANNESS
,
EmbedDType
,
Endianness
,
binary2tensor
,
tensor2binary
,
)
@
pytest
.
mark
.
parametrize
(
"endianness"
,
ENDIANNESS
)
@
pytest
.
mark
.
parametrize
(
"embed_dtype"
,
EMBED_DTYPE
_TO_TORCH_DTYPE
.
keys
())
@
pytest
.
mark
.
parametrize
(
"embed_dtype"
,
EMBED_DTYPE
S
.
keys
())
@
torch
.
inference_mode
()
def
test_encode_and_decode
(
embed_dtype
:
str
,
endianness
:
str
):
def
test_encode_and_decode
(
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
):
for
i
in
range
(
10
):
tensor
=
torch
.
rand
(
2
,
3
,
5
,
7
,
11
,
13
,
device
=
"cpu"
,
dtype
=
torch
.
float32
)
shape
=
tensor
.
shape
...
...
vllm/entrypoints/pooling/embed/serving.py
View file @
83449a5f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
collections.abc
import
AsyncGenerator
,
Mapping
from
typing
import
Any
,
Final
,
TypeAlias
from
collections.abc
import
AsyncGenerator
,
Callable
,
Mapping
from
functools
import
partial
from
typing
import
Any
,
Final
,
Literal
,
TypeAlias
,
cast
import
torch
from
fastapi
import
Request
...
...
@@ -22,16 +23,18 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse
,
EmbeddingResponseData
,
)
from
vllm.entrypoints.pooling.utils
import
(
encode_pooling_bytes
,
encode_pooling_output_base64
,
encode_pooling_output_float
,
)
from
vllm.inputs.data
import
EmbedsPrompt
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.collection_utils
import
chunk_list
from
vllm.utils.serial_utils
import
(
encode_pooling_bytes
,
encode_pooling_output
,
)
from
vllm.utils.serial_utils
import
EmbedDType
,
Endianness
logger
=
init_logger
(
__name__
)
...
...
@@ -113,79 +116,120 @@ class OpenAIServingEmbedding(OpenAIServing):
logger
.
exception
(
"Error in preprocessing prompt inputs"
)
return
self
.
create_error_response
(
str
(
e
))
def
request_output_to_embed_json_response
(
self
,
final_res_batch
:
list
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"float"
,
"base64"
],
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
EmbeddingResponse
:
encode_fn
=
cast
(
Callable
[[
PoolingRequestOutput
],
list
[
float
]
|
str
],
(
encode_pooling_output_float
if
encoding_format
==
"float"
else
partial
(
encode_pooling_output_base64
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
),
)
items
:
list
[
EmbeddingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
item
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
encode_fn
(
final_res
),
)
prompt_token_ids
=
final_res
.
prompt_token_ids
items
.
append
(
item
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
items
,
usage
=
usage
,
)
def
request_output_to_embed_bytes_response
(
self
,
final_res_batch
:
list
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"bytes"
,
"bytes_only"
],
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
EmbeddingBytesResponse
:
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
headers
=
(
None
if
encoding_format
==
"bytes_only"
else
{
"metadata"
:
json
.
dumps
(
{
"id"
:
request_id
,
"created"
:
created_time
,
"model"
:
model_name
,
"data"
:
items
,
"usage"
:
usage
,
}
)
}
)
return
EmbeddingBytesResponse
(
content
=
content
,
headers
=
headers
)
def
_build_response
(
self
,
ctx
:
EmbeddingServeContext
,
)
->
EmbeddingResponse
|
EmbeddingBytesResponse
|
ErrorResponse
:
final_res_batch_checked
=
ctx
.
final_res_batch
encoding_format
=
ctx
.
request
.
encoding_format
embed_dtype
=
ctx
.
request
.
embed_dtype
endianness
=
ctx
.
request
.
endianness
def
encode_float_base64
():
items
:
list
[
EmbeddingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch_checked
):
item
=
EmbeddingResponseData
(
index
=
idx
,
embedding
=
encode_pooling_output
(
final_res
,
encoding_format
=
encoding_format
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
),
)
prompt_token_ids
=
final_res
.
prompt_token_ids
items
.
append
(
item
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
EmbeddingResponse
(
id
=
ctx
.
request_id
,
created
=
ctx
.
created_time
,
model
=
ctx
.
model_name
,
data
=
items
,
usage
=
usage
,
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
self
.
request_output_to_embed_json_response
(
ctx
.
final_res_batch
,
ctx
.
request_id
,
ctx
.
created_time
,
ctx
.
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
def
encode_bytes
(
bytes_only
:
bool
)
->
EmbeddingBytesResponse
:
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch_checked
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
if
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
self
.
request_output_to_embed_bytes_response
(
ctx
.
final_res_batch
,
ctx
.
request_id
,
ctx
.
created_time
,
ctx
.
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
headers
=
(
None
if
bytes_only
else
{
"metadata"
:
json
.
dumps
(
{
"id"
:
ctx
.
request_id
,
"created"
:
ctx
.
created_time
,
"model"
:
ctx
.
model_name
,
"data"
:
items
,
"usage"
:
usage
,
}
)
}
)
return
EmbeddingBytesResponse
(
content
=
content
,
headers
=
headers
)
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
encode_float_base64
()
elif
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
encode_bytes
(
bytes_only
=
encoding_format
==
"bytes_only"
)
else
:
assert_never
(
encoding_format
)
assert_never
(
encoding_format
)
def
_get_max_position_embeddings
(
self
)
->
int
:
"""Get the model's effective maximum sequence length for chunking."""
...
...
vllm/entrypoints/pooling/pooling/serving.py
View file @
83449a5f
...
...
@@ -4,8 +4,9 @@
import
asyncio
import
json
import
time
from
collections.abc
import
AsyncGenerator
,
Sequence
from
typing
import
Any
,
Final
,
cast
from
collections.abc
import
AsyncGenerator
,
Callable
,
Sequence
from
functools
import
partial
from
typing
import
Any
,
Final
,
Literal
,
cast
import
jinja2
from
fastapi
import
Request
...
...
@@ -27,17 +28,16 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse
,
PoolingResponseData
,
)
from
vllm.entrypoints.pooling.utils
import
(
encode_pooling_bytes
,
encode_pooling_output_base64
,
encode_pooling_output_float
,
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.tasks
import
PoolingTask
,
SupportedTask
from
vllm.utils.async_utils
import
merge_async_iterators
from
vllm.utils.serial_utils
import
(
EmbedDType
,
EncodingFormat
,
Endianness
,
encode_pooling_bytes
,
encode_pooling_output
,
)
from
vllm.utils.serial_utils
import
EmbedDType
,
EncodingFormat
,
Endianness
logger
=
init_logger
(
__name__
)
...
...
@@ -256,79 +256,119 @@ class OpenAIServingPooling(OpenAIServing):
return
response
def
request_output_to_pooling_response
(
def
request_output_to_pooling_
json_
response
(
self
,
final_res_batch
:
list
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
EncodingFormat
,
encoding_format
:
Literal
[
"float"
,
"base64"
]
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
PoolingResponse
|
PoolingBytesResponse
:
def
encode_float_base64
():
items
:
list
[
PoolingResponseData
]
=
[]
num_prompt_tokens
=
0
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
item
=
PoolingResponseData
(
index
=
idx
,
data
=
encode_pooling_output
(
final_res
,
encoding_format
=
encoding_format
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
),
)
->
PoolingResponse
:
encode_fn
=
cast
(
Callable
[[
PoolingRequestOutput
],
list
[
float
]
|
str
],
(
encode_pooling_output_float
if
encoding_format
==
"float"
else
partial
(
encode_pooling_output_base64
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
prompt_token_ids
=
final_res
.
prompt_token_ids
),
)
items
.
append
(
item
)
num_prompt_tokens
+
=
len
(
prompt_token_ids
)
items
:
list
[
PoolingResponseData
]
=
[]
num_prompt_tokens
=
0
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
for
idx
,
final_res
in
enumerate
(
final_res_batch
):
item
=
PoolingResponseData
(
index
=
idx
,
data
=
encode_fn
(
final_res
),
)
prompt_token_ids
=
final_res
.
prompt_token_ids
return
PoolingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
items
,
usage
=
usage
,
)
items
.
append
(
item
)
num_prompt_tokens
+=
len
(
prompt_token_ids
)
def
encode_bytes
(
bytes_only
:
bool
)
->
PoolingBytesResponse
:
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
headers
=
(
None
if
bytes_only
else
{
"metadata"
:
json
.
dumps
(
{
"id"
:
request_id
,
"created"
:
created_time
,
"model"
:
model_name
,
"data"
:
items
,
"usage"
:
usage
,
}
)
}
return
PoolingResponse
(
id
=
request_id
,
created
=
created_time
,
model
=
model_name
,
data
=
items
,
usage
=
usage
,
)
def
request_output_to_pooling_bytes_response
(
self
,
final_res_batch
:
list
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
Literal
[
"bytes"
,
"bytes_only"
],
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
PoolingBytesResponse
:
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
headers
=
(
None
if
encoding_format
==
"bytes_only"
else
{
"metadata"
:
json
.
dumps
(
{
"id"
:
request_id
,
"created"
:
created_time
,
"model"
:
model_name
,
"data"
:
items
,
"usage"
:
usage
,
}
)
}
)
return
PoolingBytesResponse
(
content
=
content
,
headers
=
headers
)
def
request_output_to_pooling_response
(
self
,
final_res_batch
:
list
[
PoolingRequestOutput
],
request_id
:
str
,
created_time
:
int
,
model_name
:
str
,
encoding_format
:
EncodingFormat
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
PoolingResponse
|
PoolingBytesResponse
:
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
self
.
request_output_to_pooling_json_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
return
PoolingBytesResponse
(
content
=
content
,
headers
=
headers
,
if
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
self
.
request_output_to_pooling_bytes_response
(
final_res_batch
,
request_id
,
created_time
,
model_name
,
encoding_format
,
embed_dtype
,
endianness
,
)
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
encode_float_base64
()
elif
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
encode_bytes
(
bytes_only
=
encoding_format
==
"bytes_only"
)
else
:
assert_never
(
encoding_format
)
assert_never
(
encoding_format
)
vllm/entrypoints/pooling/utils.py
0 → 100644
View file @
83449a5f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
dataclasses
import
dataclass
from
typing
import
Any
import
pybase64
import
torch
from
vllm.outputs
import
PoolingRequestOutput
from
vllm.utils.serial_utils
import
(
EMBED_DTYPES
,
EmbedDType
,
Endianness
,
binary2tensor
,
tensor2binary
,
)
@
dataclass
class
MetadataItem
:
index
:
int
embed_dtype
:
EmbedDType
endianness
:
Endianness
start
:
int
end
:
int
shape
:
tuple
[
int
,
...]
def
build_metadata_items
(
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
shape
:
tuple
[
int
,
...],
n_request
:
int
,
)
->
list
[
MetadataItem
]:
n_bytes
=
EMBED_DTYPES
[
embed_dtype
].
nbytes
size
=
math
.
prod
(
shape
)
return
[
MetadataItem
(
index
=
i
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
start
=
i
*
size
*
n_bytes
,
end
=
(
i
+
1
)
*
size
*
n_bytes
,
shape
=
shape
,
)
for
i
in
range
(
n_request
)
]
def
encode_pooling_output_float
(
output
:
PoolingRequestOutput
)
->
list
[
float
]:
return
output
.
outputs
.
data
.
tolist
()
def
encode_pooling_output_binary
(
output
:
PoolingRequestOutput
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
bytes
:
return
tensor2binary
(
output
.
outputs
.
data
,
embed_dtype
,
endianness
)
def
encode_pooling_output_base64
(
output
:
PoolingRequestOutput
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
str
:
embedding_bytes
=
tensor2binary
(
output
.
outputs
.
data
,
embed_dtype
,
endianness
)
return
pybase64
.
b64encode
(
embedding_bytes
).
decode
(
"utf-8"
)
def
encode_pooling_bytes
(
pooling_outputs
:
list
[
PoolingRequestOutput
],
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
tuple
[
list
[
bytes
],
list
[
dict
[
str
,
Any
]],
dict
[
str
,
Any
]]:
num_prompt_tokens
=
0
items
:
list
[
dict
[
str
,
Any
]]
=
[]
body
:
list
[
bytes
]
=
[]
offset
=
0
for
idx
,
output
in
enumerate
(
pooling_outputs
):
binary
=
tensor2binary
(
tensor
=
output
.
outputs
.
data
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
size
=
len
(
binary
)
# Dictionary form of MetadataItem
item
=
dict
(
index
=
idx
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
start
=
offset
,
end
=
offset
+
size
,
shape
=
output
.
outputs
.
data
.
shape
,
)
body
.
append
(
binary
)
items
.
append
(
item
)
prompt_token_ids
=
output
.
prompt_token_ids
num_prompt_tokens
+=
len
(
prompt_token_ids
)
offset
+=
size
# Dictionary form of UsageInfo
usage
=
dict
(
prompt_tokens
=
num_prompt_tokens
,
total_tokens
=
num_prompt_tokens
,
)
return
body
,
items
,
usage
def
decode_pooling_output
(
items
:
list
[
MetadataItem
],
body
:
bytes
)
->
list
[
torch
.
Tensor
]:
return
[
binary2tensor
(
body
[
item
.
start
:
item
.
end
],
item
.
shape
,
item
.
embed_dtype
,
item
.
endianness
,
)
for
item
in
sorted
(
items
,
key
=
lambda
x
:
x
.
index
)
]
vllm/utils/serial_utils.py
View file @
83449a5f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
base64
import
io
import
math
import
sys
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
from
typing
import
Literal
,
get_args
import
numpy
as
np
import
numpy.typing
as
npt
import
pybase64
import
torch
from
typing_extensions
import
assert_never
if
TYPE_CHECKING
:
from
vllm
import
PoolingRequestOutput
else
:
PoolingRequestOutput
=
Any
sys_byteorder
=
sys
.
byteorder
EMBED_DTYPE_TO_TORCH_DTYPE
=
{
"float32"
:
torch
.
float32
,
"float16"
:
torch
.
float16
,
"bfloat16"
:
torch
.
bfloat16
,
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
"fp8_e4m3"
:
torch
.
float8_e4m3fn
,
"fp8_e5m2"
:
torch
.
float8_e5m2
,
}
EMBED_DTYPE_TO_N_BYTES
=
{
"float32"
:
4
,
"float16"
:
2
,
"bfloat16"
:
2
,
"fp8_e4m3"
:
1
,
"fp8_e5m2"
:
1
,
}
@
dataclass
(
frozen
=
True
)
class
DTypeInfo
:
torch_dtype
:
torch
.
dtype
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW
=
{
"float32"
:
torch
.
float32
,
"float16"
:
torch
.
float16
,
# numpy does not support bfloat16 and fp8
"bfloat16"
:
torch
.
float16
,
"fp8_e4m3"
:
torch
.
uint8
,
"fp8_e5m2"
:
torch
.
uint8
,
}
torch_view_dtype
:
torch
.
dtype
numpy_view_dtype
:
npt
.
DTypeLike
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
=
{
"float32"
:
np
.
float32
,
"float16"
:
np
.
float16
,
# numpy does not support bfloat16 and fp8
"bfloat16"
:
np
.
float16
,
"fp8_e4m3"
:
np
.
uint8
,
"fp8_e5m2"
:
np
.
uint8
,
}
@
property
def
nbytes
(
self
)
->
int
:
return
self
.
torch_dtype
.
itemsize
ENDIANNESS
=
[
"native"
,
"big"
,
"little"
]
EmbedDType
=
Literal
[
"float32"
,
"float16"
,
"bfloat16"
,
"fp8_e4m3"
,
"fp8_e5m2"
]
Endianness
=
Literal
[
"native"
,
"big"
,
"little"
]
EncodingFormat
=
Literal
[
"float"
,
"base64"
,
"bytes"
,
"bytes_only"
]
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
# NOTE: numpy does not support bfloat16 and fp8
EMBED_DTYPES
:
Mapping
[
EmbedDType
,
DTypeInfo
]
=
{
"float32"
:
DTypeInfo
(
torch
.
float32
,
torch
.
float32
,
np
.
float32
),
"float16"
:
DTypeInfo
(
torch
.
float16
,
torch
.
float16
,
np
.
float16
),
"bfloat16"
:
DTypeInfo
(
torch
.
bfloat16
,
torch
.
float16
,
np
.
float16
),
"fp8_e4m3"
:
DTypeInfo
(
torch
.
float8_e4m3fn
,
torch
.
uint8
,
np
.
uint8
),
"fp8_e5m2"
:
DTypeInfo
(
torch
.
float8_e5m2
,
torch
.
uint8
,
np
.
uint8
),
}
ENDIANNESS
:
tuple
[
Endianness
,
...]
=
get_args
(
Endianness
)
def
tensor2base64
(
x
:
torch
.
Tensor
)
->
str
:
with
io
.
BytesIO
()
as
buf
:
...
...
@@ -71,21 +51,26 @@ def tensor2base64(x: torch.Tensor) -> str:
buf
.
seek
(
0
)
binary_data
=
buf
.
read
()
return
base64
.
b64encode
(
binary_data
).
decode
(
"utf-8"
)
return
py
base64
.
b64encode
(
binary_data
).
decode
(
"utf-8"
)
def
tensor2binary
(
tensor
:
torch
.
Tensor
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
tensor
:
torch
.
Tensor
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
bytes
:
assert
isinstance
(
tensor
,
torch
.
Tensor
)
assert
embed_dtype
in
EMBED_DTYPE
_TO_TORCH_DTYPE
assert
embed_dtype
in
EMBED_DTYPE
S
assert
endianness
in
ENDIANNESS
torch_dtype
=
EMBED_DTYPE_TO_TORCH_DTYPE
[
embed_dtype
]
torch_view_dtype
=
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW
[
embed_dtype
]
dtype_info
=
EMBED_DTYPES
[
embed_dtype
]
np_array
=
(
tensor
.
to
(
torch_dtype
).
flatten
().
contiguous
().
view
(
torch_view_dtype
).
numpy
()
tensor
.
to
(
dtype_info
.
torch_dtype
)
.
flatten
()
.
contiguous
()
.
view
(
dtype_info
.
torch_view_dtype
)
.
numpy
()
)
if
endianness
!=
"native"
and
endianness
!=
sys_byteorder
:
...
...
@@ -100,115 +85,14 @@ def binary2tensor(
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
torch
.
Tensor
:
assert
embed_dtype
in
EMBED_DTYPE_TO_TORCH_DTYPE
assert
embed_dtype
in
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
assert
embed_dtype
in
EMBED_DTYPES
assert
endianness
in
ENDIANNESS
torch_dtype
=
EMBED_DTYPE_TO_TORCH_DTYPE
[
embed_dtype
]
np_dtype
=
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
[
embed_dtype
]
dtype_info
=
EMBED_DTYPES
[
embed_dtype
]
np_array
=
np
.
frombuffer
(
binary
,
dtype
=
np
_dtype
).
reshape
(
shape
)
np_array
=
np
.
frombuffer
(
binary
,
dtype
=
dtype_info
.
numpy_view
_dtype
).
reshape
(
shape
)
if
endianness
!=
"native"
and
endianness
!=
sys_byteorder
:
np_array
=
np_array
.
byteswap
()
return
torch
.
from_numpy
(
np_array
).
view
(
torch_dtype
)
def
encode_pooling_output
(
output
:
PoolingRequestOutput
,
encoding_format
:
EncodingFormat
,
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
)
->
list
[
float
]
|
str
|
bytes
:
if
encoding_format
==
"float"
:
return
output
.
outputs
.
data
.
tolist
()
elif
encoding_format
==
"base64"
:
embedding_bytes
=
tensor2binary
(
output
.
outputs
.
data
,
embed_dtype
,
endianness
)
return
base64
.
b64encode
(
embedding_bytes
).
decode
(
"utf-8"
)
elif
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
tensor2binary
(
output
.
outputs
.
data
,
embed_dtype
,
endianness
)
assert_never
(
encoding_format
)
@
dataclass
class
MetadataItem
:
index
:
int
embed_dtype
:
EmbedDType
endianness
:
Endianness
start
:
int
end
:
int
shape
:
tuple
[
int
,
...]
def
build_metadata_items
(
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
shape
:
tuple
[
int
,
...],
n_request
:
int
,
):
n_bytes
=
EMBED_DTYPE_TO_N_BYTES
[
embed_dtype
]
size
=
math
.
prod
(
shape
)
items
=
[
MetadataItem
(
index
=
i
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
start
=
i
*
size
*
n_bytes
,
end
=
(
i
+
1
)
*
size
*
n_bytes
,
shape
=
shape
,
)
for
i
in
range
(
n_request
)
]
return
items
def
encode_pooling_bytes
(
pooling_outputs
:
list
[
PoolingRequestOutput
],
embed_dtype
:
EmbedDType
,
endianness
:
Endianness
,
):
num_prompt_tokens
=
0
items
:
list
[
dict
[
str
,
MetadataItem
]]
=
[]
body
=
[]
offset
=
0
for
idx
,
output
in
enumerate
(
pooling_outputs
):
binary
=
tensor2binary
(
tensor
=
output
.
outputs
.
data
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
size
=
len
(
binary
)
item
=
{
"index"
:
idx
,
"embed_dtype"
:
embed_dtype
,
"endianness"
:
endianness
,
"start"
:
offset
,
"end"
:
offset
+
size
,
"shape"
:
output
.
outputs
.
data
.
shape
,
}
body
.
append
(
binary
)
items
.
append
(
item
)
prompt_token_ids
=
output
.
prompt_token_ids
num_prompt_tokens
+=
len
(
prompt_token_ids
)
offset
+=
size
usage
=
{
"prompt_tokens"
:
num_prompt_tokens
,
"total_tokens"
:
num_prompt_tokens
,
}
return
body
,
items
,
usage
def
decode_pooling_output
(
items
:
list
[
MetadataItem
],
body
:
bytes
)
->
list
[
torch
.
Tensor
]:
items
.
sort
(
key
=
lambda
x
:
x
.
index
)
tensor_list
:
list
[
torch
.
Tensor
]
=
[]
for
item
in
items
:
binary
=
body
[
item
.
start
:
item
.
end
]
tensor
=
binary2tensor
(
binary
,
item
.
shape
,
item
.
embed_dtype
,
item
.
endianness
)
tensor_list
.
append
(
tensor
)
return
tensor_list
return
torch
.
from_numpy
(
np_array
).
view
(
dtype_info
.
torch_dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment