Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51f86bf4
Unverified
Commit
51f86bf4
authored
Aug 28, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 27, 2024
Browse files
[mypy][CI/Build] Fix mypy errors (#7929)
parent
c166e7e4
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
24 additions
and
9 deletions
+24
-9
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+5
-0
vllm/assets/audio.py
vllm/assets/audio.py
+3
-1
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+3
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+12
-5
vllm/sequence.py
vllm/sequence.py
+1
-1
No files found.
tests/samplers/test_sampler.py
View file @
51f86bf4
...
...
@@ -418,6 +418,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
prompt_len
=
seq_data
.
get_prompt_len
()
seq_lens
.
append
(
prompt_len
)
assert
sgm
.
sampling_params
is
not
None
if
sgm
.
sampling_params
.
prompt_logprobs
:
# with prompt_logprobs each token in the prompt has a row in
# logits
...
...
@@ -533,6 +534,8 @@ def test_sampler_mixed(seed: int, device: str):
for
i
,
(
sequence_output
,
metadata
)
in
enumerate
(
zip
(
sampler_output
,
seq_group_metadata_list
)):
assert
metadata
.
sampling_params
is
not
None
if
metadata
.
sampling_params
.
use_beam_search
:
continue
...
...
@@ -550,6 +553,8 @@ def test_sampler_mixed(seed: int, device: str):
assert
expected_tokens_item
is
not
None
for
n
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
assert
metadata
.
sampling_params
is
not
None
if
(
metadata
.
sampling_params
.
temperature
==
0
or
metadata
.
sampling_params
.
seed
is
not
None
):
# Ensure exact matches for greedy or random with seed
...
...
vllm/assets/audio.py
View file @
51f86bf4
...
...
@@ -19,7 +19,9 @@ class AudioAsset:
audio_path
=
get_vllm_public_assets
(
filename
=
f
"
{
self
.
name
}
.ogg"
,
s3_prefix
=
ASSET_DIR
)
return
librosa
.
load
(
audio_path
,
sr
=
None
)
y
,
sr
=
librosa
.
load
(
audio_path
,
sr
=
None
)
assert
isinstance
(
sr
,
int
)
return
y
,
sr
@
property
def
url
(
self
)
->
str
:
...
...
vllm/entrypoints/openai/rpc/client.py
View file @
51f86bf4
...
...
@@ -101,6 +101,7 @@ class AsyncEngineRPCClient:
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit
=
self
.
context
.
get
(
zmq
.
constants
.
SOCKET_LIMIT
)
assert
isinstance
(
socket_limit
,
int
)
if
socket_limit
<
VLLM_RPC_SOCKET_LIMIT_CUTOFF
:
raise
ValueError
(
f
"Found zmq.constants.SOCKET_LIMIT=
{
socket_limit
}
, which caps "
...
...
@@ -141,8 +142,8 @@ class AsyncEngineRPCClient:
poller
.
register
(
socket_from
,
zmq
.
constants
.
POLLIN
)
poller
.
register
(
socket_to
,
zmq
.
constants
.
POLLIN
)
while
True
:
events
=
await
poller
.
poll
()
events
=
dict
(
events
)
events
_lst
=
await
poller
.
poll
()
events
=
dict
(
events
_lst
)
if
socket_from
in
events
:
identity
,
msg
=
await
socket_from
.
recv_multipart
()
await
socket_to
.
send_multipart
([
identity
,
msg
])
...
...
vllm/multimodal/base.py
View file @
51f86bf4
...
...
@@ -14,7 +14,7 @@ from typing_extensions import TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.utils
import
json_map_leaves
from
vllm.utils
import
JSONTree
,
is_list_of
,
json_map_leaves
logger
=
init_logger
(
__name__
)
...
...
@@ -54,13 +54,14 @@ class MultiModalInputs(_MultiModalInputsBase):
return
nested_tensors
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
if
any
(
isinstance
(
t
,
list
)
for
t
in
stacked
):
if
is_list_of
(
stacked
,
list
):
# Do not stack nested lists
return
stacked
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
if
any
(
t
.
shape
!=
tensors_
[
0
].
shape
for
t
in
tensors_
):
# The tensors have incompatible shapes and can't be stacked.
return
tensors_
return
stacked
return
torch
.
stack
(
tensors_
)
...
...
@@ -101,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
*
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensorInputs
:
return
json_map_leaves
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
True
),
batched_inputs
)
json_inputs
=
cast
(
JSONTree
[
torch
.
Tensor
],
batched_inputs
)
json_mapped
=
json_map_leaves
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
True
),
json_inputs
,
)
return
cast
(
BatchedTensorInputs
,
json_mapped
)
_T
=
TypeVar
(
"_T"
)
...
...
vllm/sequence.py
View file @
51f86bf4
...
...
@@ -883,7 +883,7 @@ class SequenceGroupMetadata(
request_id
:
str
is_prompt
:
bool
seq_data
:
Dict
[
int
,
SequenceData
]
sampling_params
:
SamplingParams
sampling_params
:
Optional
[
SamplingParams
]
block_tables
:
Dict
[
int
,
List
[
int
]]
do_sample
:
bool
=
True
pooling_params
:
Optional
[
PoolingParams
]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment