Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
97234be0
Unverified
Commit
97234be0
authored
Jul 23, 2024
by
Cyrus Leung
Committed by
GitHub
Jul 22, 2024
Browse files
[Misc] Manage HTTP connections in one place (#6600)
parent
c051bfe4
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
215 additions
and
85 deletions
+215
-85
tests/conftest.py
tests/conftest.py
+8
-0
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+4
-6
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+5
-5
vllm/assets/image.py
vllm/assets/image.py
+6
-7
vllm/connections.py
vllm/connections.py
+167
-0
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+22
-66
vllm/usage/usage_lib.py
vllm/usage/usage_lib.py
+3
-1
No files found.
tests/conftest.py
View file @
97234be0
...
...
@@ -16,6 +16,7 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
TokenizerPoolConfig
from
vllm.connections
import
global_http_connection
from
vllm.distributed
import
(
destroy_distributed_environment
,
destroy_model_parallel
)
from
vllm.inputs
import
TextPrompt
...
...
@@ -74,6 +75,13 @@ IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
@
pytest
.
fixture
(
autouse
=
True
)
def
init_test_http_connection
():
# pytest_asyncio may use a different event loop per test
# so we need to make sure the async client is created anew
global_http_connection
.
reuse_client
=
False
def
cleanup
():
destroy_model_parallel
()
destroy_distributed_environment
()
...
...
tests/entrypoints/openai/test_vision.py
View file @
97234be0
...
...
@@ -2,9 +2,8 @@ from typing import Dict, List
import
openai
import
pytest
import
pytest_asyncio
from
vllm.multimodal.utils
import
ImageFetchAiohttp
,
encode_image_base64
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
...utils
import
VLLM_PATH
,
RemoteOpenAIServer
...
...
@@ -42,11 +41,10 @@ def client(server):
return
server
.
get_async_client
()
@
pytest
_asyncio
.
fixture
(
scope
=
"session"
)
async
def
base64_encoded_image
()
->
Dict
[
str
,
str
]:
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
Dict
[
str
,
str
]:
return
{
image_url
:
encode_image_base64
(
await
ImageFetchAiohttp
.
fetch_image
(
image_url
))
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
}
...
...
tests/multimodal/test_utils.py
View file @
97234be0
...
...
@@ -7,7 +7,7 @@ import numpy as np
import
pytest
from
PIL
import
Image
from
vllm.multimodal.utils
import
ImageFetchAiohttp
,
fetch_image
from
vllm.multimodal.utils
import
async_fetch_image
,
fetch_image
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
...
...
@@ -37,15 +37,15 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
return
(
np
.
asarray
(
a
)
==
np
.
asarray
(
b
.
convert
(
a
.
mode
))).
all
()
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
async
def
test_fetch_image_http
(
image_url
:
str
):
image_sync
=
fetch_image
(
image_url
)
image_async
=
await
ImageFetchAiohttp
.
fetch_image
(
image_url
)
image_async
=
await
async_
fetch_image
(
image_url
)
assert
_image_equals
(
image_sync
,
image_async
)
@
pytest
.
mark
.
asyncio
(
scope
=
"module"
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_URLS
)
@
pytest
.
mark
.
parametrize
(
"suffix"
,
get_supported_suffixes
())
async
def
test_fetch_image_base64
(
url_images
:
Dict
[
str
,
Image
.
Image
],
...
...
@@ -78,5 +78,5 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
else
:
pass
# Lossy format; only check that image can be opened
data_image_async
=
await
ImageFetchAiohttp
.
fetch_image
(
data_url
)
data_image_async
=
await
async_
fetch_image
(
data_url
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
vllm/assets/image.py
View file @
97234be0
import
shutil
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
Literal
import
requests
from
PIL
import
Image
from
vllm.connections
import
global_http_connection
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
from
.base
import
get_cache_dir
...
...
@@ -22,11 +23,9 @@ def get_air_example_data_2_asset(filename: str) -> Image.Image:
if
not
image_path
.
exists
():
base_url
=
"https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava"
with
requests
.
get
(
f
"
{
base_url
}
/
{
filename
}
"
,
stream
=
True
)
as
response
:
response
.
raise_for_status
()
with
image_path
.
open
(
"wb"
)
as
f
:
shutil
.
copyfileobj
(
response
.
raw
,
f
)
global_http_connection
.
download_file
(
f
"
{
base_url
}
/
{
filename
}
"
,
image_path
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
return
Image
.
open
(
image_path
)
...
...
vllm/connections.py
0 → 100644
View file @
97234be0
from
pathlib
import
Path
from
typing
import
Mapping
,
Optional
from
urllib.parse
import
urlparse
import
aiohttp
import
requests
from
vllm.version
import
__version__
as
VLLM_VERSION
class
HTTPConnection
:
"""Helper class to send HTTP requests."""
def
__init__
(
self
,
*
,
reuse_client
:
bool
=
True
)
->
None
:
super
().
__init__
()
self
.
reuse_client
=
reuse_client
self
.
_sync_client
:
Optional
[
requests
.
Session
]
=
None
self
.
_async_client
:
Optional
[
aiohttp
.
ClientSession
]
=
None
def
get_sync_client
(
self
)
->
requests
.
Session
:
if
self
.
_sync_client
is
None
or
not
self
.
reuse_client
:
self
.
_sync_client
=
requests
.
Session
()
return
self
.
_sync_client
# NOTE: We intentionally use an async function even though it is not
# required, so that the client is only accessible inside async event loop
async
def
get_async_client
(
self
)
->
aiohttp
.
ClientSession
:
if
self
.
_async_client
is
None
or
not
self
.
reuse_client
:
self
.
_async_client
=
aiohttp
.
ClientSession
()
return
self
.
_async_client
def
_validate_http_url
(
self
,
url
:
str
):
parsed_url
=
urlparse
(
url
)
if
parsed_url
.
scheme
not
in
(
"http"
,
"https"
):
raise
ValueError
(
"Invalid HTTP URL: A valid HTTP URL "
"must have scheme 'http' or 'https'."
)
def
_headers
(
self
,
**
extras
:
str
)
->
Mapping
[
str
,
str
]:
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
,
**
extras
}
def
get_response
(
self
,
url
:
str
,
*
,
stream
:
bool
=
False
,
timeout
:
Optional
[
float
]
=
None
,
extra_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
):
self
.
_validate_http_url
(
url
)
client
=
self
.
get_sync_client
()
extra_headers
=
extra_headers
or
{}
return
client
.
get
(
url
,
headers
=
self
.
_headers
(
**
extra_headers
),
stream
=
stream
,
timeout
=
timeout
)
async
def
get_async_response
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
,
extra_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
):
self
.
_validate_http_url
(
url
)
client
=
await
self
.
get_async_client
()
extra_headers
=
extra_headers
or
{}
return
client
.
get
(
url
,
headers
=
self
.
_headers
(
**
extra_headers
),
timeout
=
timeout
)
def
get_bytes
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
)
->
bytes
:
with
self
.
get_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
return
r
.
content
async
def
async_get_bytes
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
,
)
->
bytes
:
async
with
await
self
.
get_async_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
return
await
r
.
read
()
def
get_text
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
)
->
str
:
with
self
.
get_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
return
r
.
text
async
def
async_get_text
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
,
)
->
str
:
async
with
await
self
.
get_async_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
return
await
r
.
text
()
def
get_json
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
)
->
str
:
with
self
.
get_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
return
r
.
json
()
async
def
async_get_json
(
self
,
url
:
str
,
*
,
timeout
:
Optional
[
float
]
=
None
,
)
->
str
:
async
with
await
self
.
get_async_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
return
await
r
.
json
()
def
download_file
(
self
,
url
:
str
,
save_path
:
Path
,
*
,
timeout
:
Optional
[
float
]
=
None
,
chunk_size
:
int
=
128
,
)
->
Path
:
with
self
.
get_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
with
save_path
.
open
(
"wb"
)
as
f
:
for
chunk
in
r
.
iter_content
(
chunk_size
):
f
.
write
(
chunk
)
return
save_path
async
def
async_download_file
(
self
,
url
:
str
,
save_path
:
Path
,
*
,
timeout
:
Optional
[
float
]
=
None
,
chunk_size
:
int
=
128
,
)
->
Path
:
async
with
await
self
.
get_async_response
(
url
,
timeout
=
timeout
)
as
r
:
r
.
raise_for_status
()
with
save_path
.
open
(
"wb"
)
as
f
:
async
for
chunk
in
r
.
content
.
iter_chunked
(
chunk_size
):
f
.
write
(
chunk
)
return
save_path
global_http_connection
=
HTTPConnection
()
"""The global :class:`HTTPConnection` instance used by vLLM."""
vllm/multimodal/utils.py
View file @
97234be0
import
base64
from
io
import
BytesIO
from
typing
import
Optional
,
Union
from
urllib.parse
import
urlparse
from
typing
import
Union
import
aiohttp
import
requests
from
PIL
import
Image
from
vllm.connections
import
global_http_connection
from
vllm.envs
import
VLLM_IMAGE_FETCH_TIMEOUT
from
vllm.multimodal.base
import
MultiModalDataDict
from
vllm.version
import
__version__
as
VLLM_VERSION
def
_validate_remote_url
(
url
:
str
,
*
,
name
:
str
):
parsed_url
=
urlparse
(
url
)
if
parsed_url
.
scheme
not
in
[
"http"
,
"https"
]:
raise
ValueError
(
f
"Invalid '
{
name
}
': A valid '
{
name
}
' "
"must have scheme 'http' or 'https'."
)
def
_get_request_headers
():
return
{
"User-Agent"
:
f
"vLLM/
{
VLLM_VERSION
}
"
}
def
_load_image_from_bytes
(
b
:
bytes
):
...
...
@@ -42,13 +28,8 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
By default, the image is converted into RGB format.
"""
if
image_url
.
startswith
(
'http'
):
_validate_remote_url
(
image_url
,
name
=
"image_url"
)
headers
=
_get_request_headers
()
with
requests
.
get
(
url
=
image_url
,
headers
=
headers
)
as
response
:
response
.
raise_for_status
()
image_raw
=
response
.
content
image_raw
=
global_http_connection
.
get_bytes
(
image_url
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
image
=
_load_image_from_bytes
(
image_raw
)
elif
image_url
.
startswith
(
'data:image'
):
...
...
@@ -60,55 +41,30 @@ def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
return
image
.
convert
(
image_mode
)
class
ImageFetchAiohttp
:
aiohttp_client
:
Optional
[
aiohttp
.
ClientSession
]
=
None
@
classmethod
def
get_aiohttp_client
(
cls
)
->
aiohttp
.
ClientSession
:
if
cls
.
aiohttp_client
is
None
:
timeout
=
aiohttp
.
ClientTimeout
(
total
=
VLLM_IMAGE_FETCH_TIMEOUT
)
connector
=
aiohttp
.
TCPConnector
()
cls
.
aiohttp_client
=
aiohttp
.
ClientSession
(
timeout
=
timeout
,
connector
=
connector
)
return
cls
.
aiohttp_client
@
classmethod
async
def
fetch_image
(
cls
,
image_url
:
str
,
*
,
image_mode
:
str
=
"RGB"
,
)
->
Image
.
Image
:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if
image_url
.
startswith
(
'http'
):
_validate_remote_url
(
image_url
,
name
=
"image_url"
)
client
=
cls
.
get_aiohttp_client
()
headers
=
_get_request_headers
()
async
def
async_fetch_image
(
image_url
:
str
,
*
,
image_mode
:
str
=
"RGB"
)
->
Image
.
Image
:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
async
with
client
.
get
(
url
=
image_url
,
headers
=
headers
)
as
response
:
response
.
raise_for_status
()
image_raw
=
await
response
.
read
()
image
=
_load_image_from_bytes
(
image_raw
)
By default, the image is converted into RGB format.
"""
if
image_url
.
startswith
(
'http'
):
image_raw
=
await
global_http_connection
.
async_get_bytes
(
image_url
,
timeout
=
VLLM_IMAGE_FETCH_TIMEOUT
)
image
=
_load_image_from_bytes
(
image_raw
)
elif
image_url
.
startswith
(
'data:image'
):
image
=
_load_image_from_data_url
(
image_url
)
else
:
raise
ValueError
(
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'."
)
elif
image_url
.
startswith
(
'data:image'
):
image
=
_load_image_from_data_url
(
image_url
)
else
:
raise
ValueError
(
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'."
)
return
image
.
convert
(
image_mode
)
return
image
.
convert
(
image_mode
)
async
def
async_get_and_parse_image
(
image_url
:
str
)
->
MultiModalDataDict
:
image
=
await
ImageFetchAiohttp
.
fetch_image
(
image_url
)
image
=
await
async_
fetch_image
(
image_url
)
return
{
"image"
:
image
}
...
...
vllm/usage/usage_lib.py
View file @
97234be0
...
...
@@ -16,6 +16,7 @@ import requests
import
torch
import
vllm.envs
as
envs
from
vllm.connections
import
global_http_connection
from
vllm.version
import
__version__
as
VLLM_VERSION
_config_home
=
envs
.
VLLM_CONFIG_ROOT
...
...
@@ -204,7 +205,8 @@ class UsageMessage:
def
_send_to_server
(
self
,
data
):
try
:
requests
.
post
(
_USAGE_STATS_SERVER
,
json
=
data
)
global_http_client
=
global_http_connection
.
get_sync_client
()
global_http_client
.
post
(
_USAGE_STATS_SERVER
,
json
=
data
)
except
requests
.
exceptions
.
RequestException
:
# silently ignore unless we are using debug log
logging
.
debug
(
"Failed to send usage data to server"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment