Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d917d0e
Unverified
Commit
6d917d0e
authored
Dec 14, 2024
by
Mark McLoughlin
Committed by
GitHub
Dec 14, 2024
Browse files
Enable mypy checking on V1 code (#11105)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
93abf23a
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
159 additions
and
120 deletions
+159
-120
tools/mypy.sh
tools/mypy.sh
+1
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+5
-5
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+9
-8
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+1
-0
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+14
-9
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+6
-5
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+10
-10
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+24
-19
vllm/v1/engine/detokenizer.py
vllm/v1/engine/detokenizer.py
+2
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+2
-1
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+13
-7
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+1
-1
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+2
-10
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+8
-7
vllm/v1/executor/uniproc_executor.py
vllm/v1/executor/uniproc_executor.py
+4
-3
vllm/v1/request.py
vllm/v1/request.py
+1
-2
vllm/v1/utils.py
vllm/v1/utils.py
+27
-15
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+26
-16
No files found.
tools/mypy.sh
View file @
6d917d0e
...
...
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/v1
vllm/v1/attention/backends/flash_attn.py
View file @
6d917d0e
...
...
@@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
assert
output
is
not
None
,
"Output tensor must be provided."
if
attn_metadata
is
None
:
# Profiling run.
return
output
...
...
vllm/v1/core/kv_cache_manager.py
View file @
6d917d0e
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
...
...
@@ -263,12 +263,13 @@ class KVCacheManager:
"""
# Default to [] in case a request is freed (aborted) before alloc.
blocks
=
self
.
req_to_blocks
.
pop
(
request
.
request_id
,
[])
ordered_blocks
:
Iterable
[
KVCacheBlock
]
=
blocks
if
self
.
enable_caching
:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks
=
reversed
(
blocks
)
ordered_blocks
=
reversed
(
blocks
)
for
block
in
blocks
:
for
block
in
ordered_blocks
:
block
.
decr_ref
()
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
...
...
@@ -396,8 +397,7 @@ class KVCacheManager:
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
tuple
(
block_tokens
))
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
)
# Update the block and add the full block to the cache.
blk
.
block_hash
=
block_hash
...
...
vllm/v1/core/kv_cache_utils.py
View file @
6d917d0e
"""KV-Cache Utilities."""
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
...
...
@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
collision happens when the hash value is the same.
"""
hash_value
:
int
token_ids
:
Tuple
[
int
]
token_ids
:
Tuple
[
int
,
...
]
@
dataclass
...
...
@@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
self
.
num_free_blocks
=
len
(
blocks
)
# Initialize the doubly linked list of free blocks.
self
.
free_list_head
=
blocks
[
0
]
self
.
free_list_tail
=
blocks
[
-
1
]
self
.
free_list_head
:
Optional
[
KVCacheBlock
]
=
blocks
[
0
]
self
.
free_list_tail
:
Optional
[
KVCacheBlock
]
=
blocks
[
-
1
]
for
i
in
range
(
self
.
num_free_blocks
):
if
i
>
0
:
blocks
[
i
].
prev_free_block
=
blocks
[
i
-
1
]
...
...
@@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Tuple
[
int
])
->
BlockHashType
:
curr_block_token_ids
:
Sequence
[
int
])
->
BlockHashType
:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
...
...
@@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A
tuple
of token ids in the current
curr_block_token_ids: A
sequence
of token ids in the current
block. The current block is assumed to be full.
Returns:
...
...
@@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The entire tuple is used as the hash key of the block.
"""
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
curr_block_token_ids
)
tuple
(
curr_block_token_ids
)
)
def
hash_request_tokens
(
block_size
:
int
,
token_ids
:
List
[
int
])
->
List
[
BlockHashType
]:
token_ids
:
Sequence
[
int
])
->
List
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
...
...
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
parent_block_hash_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
end
=
start
+
block_size
block_token_ids
=
tuple
(
token_ids
[
start
:
end
]
)
block_token_ids
=
token_ids
[
start
:
end
]
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
break
...
...
vllm/v1/core/scheduler.py
View file @
6d917d0e
...
...
@@ -152,6 +152,7 @@ class Scheduler:
break
if
not
can_schedule
:
break
assert
new_blocks
is
not
None
# Schedule the request.
scheduled_running_reqs
.
append
(
request
)
...
...
vllm/v1/engine/__init__.py
View file @
6d917d0e
...
...
@@ -36,7 +36,7 @@ class EngineCoreRequest:
prompt
:
Optional
[
str
]
prompt_token_ids
:
List
[
int
]
mm_inputs
:
Optional
[
List
[
Optional
[
MultiModalKwargs
]]]
mm_hashes
:
Optional
[
List
[
Optional
[
str
]]
]
mm_hashes
:
Optional
[
List
[
str
]]
mm_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
sampling_params
:
SamplingParams
eos_token_id
:
Optional
[
int
]
...
...
@@ -44,10 +44,11 @@ class EngineCoreRequest:
lora_request
:
Optional
[
LoRARequest
]
class
EngineCoreOutput
(
msgspec
.
Struct
,
array_like
=
True
,
omit_defaults
=
True
,
gc
=
False
):
class
EngineCoreOutput
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
request_id
:
str
new_token_ids
:
List
[
int
]
...
...
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
class
EngineCoreOutputs
(
msgspec
.
Struct
,
array_like
=
True
,
omit_defaults
=
True
,
gc
=
False
):
class
EngineCoreOutputs
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
...
...
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
PROFILE
=
b
'
\x02
'
EngineCoreRequestUnion
=
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]]
vllm/v1/engine/async_llm.py
View file @
6d917d0e
...
...
@@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
asyncio_mode
=
True
,
)
self
.
output_handler
=
None
self
.
output_handler
:
Optional
[
asyncio
.
Task
]
=
None
def
__del__
(
self
):
self
.
shutdown
()
...
...
@@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
handler
.
cancel
()
@
classmethod
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
):
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
)
->
Type
[
Executor
]:
executor_class
:
Type
[
Executor
]
distributed_executor_backend
=
(
vllm_config
.
parallel_config
.
distributed_executor_backend
)
if
distributed_executor_backend
==
"mp"
:
...
...
@@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
logger
.
debug
(
"Called check_health."
)
async
def
start_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile
(
True
)
await
self
.
engine_core
.
profile_async
(
True
)
async
def
stop_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile
(
False
)
await
self
.
engine_core
.
profile_async
(
False
)
@
property
def
is_running
(
self
)
->
bool
:
...
...
@@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
@
property
def
dead_error
(
self
)
->
BaseException
:
return
Exception
return
Exception
()
# TODO: implement
# Retain V0 name for backwards compatibility.
...
...
vllm/v1/engine/core.py
View file @
6d917d0e
...
...
@@ -5,7 +5,7 @@ import threading
import
time
from
dataclasses
import
dataclass
from
multiprocessing.process
import
BaseProcess
from
typing
import
List
,
Tuple
,
Type
,
Union
from
typing
import
List
,
Tuple
,
Type
import
zmq
import
zmq.asyncio
...
...
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
,
EngineCoreRequestUnion
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.request
import
Request
,
RequestStatus
...
...
@@ -97,8 +97,10 @@ class EngineCore:
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
request
.
mm_inputs
=
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
)
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
,
request
.
mm_hashes
=
(
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
))
req
=
Request
.
from_engine_core_request
(
request
)
...
...
@@ -128,7 +130,7 @@ class EngineCore:
def
shutdown
(
self
):
self
.
model_executor
.
shutdown
()
def
profile
(
self
,
is_start
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
model_executor
.
profile
(
is_start
)
...
...
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
queue
.
Queue
()
self
.
output_queue
=
queue
.
Queue
()
self
.
input_queue
:
queue
.
Queue
[
EngineCoreRequestUnion
]
=
queue
.
Queue
()
self
.
output_queue
:
queue
.
Queue
[
List
[
EngineCoreOutput
]]
=
queue
.
Queue
()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_path
,
),
daemon
=
True
).
start
()
...
...
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
self
.
_last_logging_time
=
now
def
_handle_client_request
(
self
,
request
:
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]])
->
None
:
def
_handle_client_request
(
self
,
request
:
EngineCoreRequestUnion
)
->
None
:
"""Handle an EngineCoreRequest, EngineCoreProfile, or abort request (list of request ids) from Client."""
if
isinstance
(
request
,
EngineCoreRequest
):
...
...
vllm/v1/engine/core_client.py
View file @
6d917d0e
import
atexit
import
os
from
typing
import
List
,
Union
from
typing
import
List
,
Optional
import
msgspec
import
zmq
...
...
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
from
vllm.utils
import
get_open_zmq_ipc_path
,
kill_process_tree
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
EngineCoreRequestType
,
EngineCoreRequestUnion
)
from
vllm.v1.engine.core
import
(
EngineCore
,
EngineCoreProc
,
EngineCoreProcHandle
)
from
vllm.v1.serial_utils
import
PickleEncoder
logger
=
init_logger
(
__name__
)
...
...
@@ -59,7 +60,7 @@ class EngineCoreClient:
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
raise
NotImplementedError
async
def
profile
(
self
,
is_start
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
raise
NotImplementedError
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
...
...
@@ -71,6 +72,9 @@ class EngineCoreClient:
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
raise
NotImplementedError
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
raise
NotImplementedError
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
...
...
@@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
def
__del__
(
self
):
self
.
shutdown
()
def
profile
(
self
,
is_start
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
engine_core
.
profile
(
is_start
)
...
...
@@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
self
.
decoder
=
msgspec
.
msgpack
.
Decoder
(
EngineCoreOutputs
)
# ZMQ setup.
self
.
ctx
=
(
zmq
.
asyncio
.
Context
()
if
asyncio_mode
else
zmq
.
Context
())
if
asyncio_mode
:
self
.
ctx
=
zmq
.
asyncio
.
Context
()
else
:
self
.
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
# Path for IPC.
ready_path
=
get_open_zmq_ipc_path
()
...
...
@@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
self
.
input_socket
.
bind
(
input_path
)
# Start EngineCore in background process.
self
.
proc_handle
:
Optional
[
EngineCoreProcHandle
]
self
.
proc_handle
=
EngineCoreProc
.
make_engine_core_process
(
*
args
,
input_path
=
input_path
,
output_path
=
output_path
,
ready_path
=
ready_path
,
input_path
=
input_path
,
# type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path
=
output_path
,
# type: ignore[misc]
ready_path
=
ready_path
,
# type: ignore[misc]
**
kwargs
,
)
atexit
.
register
(
self
.
shutdown
)
...
...
@@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
engine_core_outputs
=
self
.
decoder
.
decode
(
frame
.
buffer
).
outputs
return
engine_core_outputs
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]])
->
None
:
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
EngineCoreRequestUnion
)
->
None
:
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
...
...
@@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
def
profile
(
self
,
is_start
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
EngineCoreProfile
(
is_start
))
...
...
@@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
return
engine_core_outputs
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]])
->
None
:
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
request
:
EngineCoreRequestUnion
)
->
None
:
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
...
...
@@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
if
len
(
request_ids
)
>
0
:
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
async
def
profile
(
self
,
is_start
=
True
)
->
None
:
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
EngineCoreProfile
(
is_start
))
vllm/v1/engine/detokenizer.py
View file @
6d917d0e
from
dataclasses
import
dataclass
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
...
...
@@ -97,7 +97,7 @@ class IncrementalDetokenizer:
self
,
new_token_ids
:
List
[
int
],
finish_reason
:
Optional
[
str
],
stop_reason
:
Optional
[
str
],
stop_reason
:
Optional
[
Union
[
int
,
str
,
None
]
],
)
->
Optional
[
RequestOutput
]:
"""
Update RequestState for the request_id by:
...
...
vllm/v1/engine/llm_engine.py
View file @
6d917d0e
...
...
@@ -103,7 +103,8 @@ class LLMEngine:
multiprocess_mode
=
enable_multiprocessing
)
@
classmethod
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
):
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
)
->
Type
[
Executor
]:
executor_class
:
Type
[
Executor
]
distributed_executor_backend
=
(
vllm_config
.
parallel_config
.
distributed_executor_backend
)
if
distributed_executor_backend
==
"mp"
:
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
6d917d0e
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
PIL
from
blake3
import
blake3
...
...
@@ -42,14 +42,14 @@ class MMInputMapperClient:
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_cache
=
LRUDictCache
(
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
]
(
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
self
.
mm_debug_cache_hit_ratio_steps
=
None
self
.
mm_cache_hits
=
0
self
.
mm_cache_total
=
0
def
cache_hit_ratio
(
self
,
steps
)
->
float
:
def
cache_hit_ratio
(
self
,
steps
):
if
self
.
mm_cache_total
>
0
and
self
.
mm_cache_total
%
steps
==
0
:
logger
.
debug
(
"MMInputMapper: cache_hit_ratio = %.2f "
,
self
.
mm_cache_hits
/
self
.
mm_cache_total
)
...
...
@@ -60,7 +60,7 @@ class MMInputMapperClient:
mm_hashes
:
Optional
[
List
[
str
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
)
->
List
[
MultiModalKwargs
]:
)
->
Tuple
[
List
[
MultiModalKwargs
]
,
Optional
[
List
[
str
]]]
:
if
precomputed_mm_inputs
is
None
:
image_inputs
=
mm_data
[
"image"
]
if
not
isinstance
(
image_inputs
,
list
):
...
...
@@ -72,6 +72,7 @@ class MMInputMapperClient:
# Check if hash is enabled
use_hash
=
mm_hashes
is
not
None
if
use_hash
:
assert
mm_hashes
is
not
None
assert
num_inputs
==
len
(
mm_hashes
),
"num_inputs = {} len(mm_hashes) = {}"
.
format
(
num_inputs
,
len
(
mm_hashes
))
...
...
@@ -79,7 +80,7 @@ class MMInputMapperClient:
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes
=
[]
if
use_hash
else
None
ret_hashes
:
Optional
[
List
[
str
]]
=
[]
if
use_hash
else
None
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
for
input_id
in
range
(
num_inputs
):
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
...
...
@@ -88,6 +89,7 @@ class MMInputMapperClient:
mm_hash
=
None
mm_input
=
None
if
use_hash
:
assert
mm_hashes
is
not
None
mm_hash
=
mm_hashes
[
input_id
]
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
...
...
@@ -105,12 +107,15 @@ class MMInputMapperClient:
if
use_hash
:
# Add to cache
assert
mm_hash
is
not
None
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
else
:
self
.
mm_cache_hits
+=
1
mm_input
=
None
# Avoids sending mm_input to Server
if
use_hash
:
assert
mm_hash
is
not
None
assert
ret_hashes
is
not
None
ret_hashes
.
append
(
mm_hash
)
ret_inputs
.
append
(
mm_input
)
...
...
@@ -120,17 +125,18 @@ class MMInputMapperClient:
class
MMInputMapperServer
:
def
__init__
(
self
,
):
self
.
mm_cache
=
LRUDictCache
(
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
]
(
MM_CACHE_SIZE
)
def
process_inputs
(
self
,
mm_inputs
:
List
[
Optional
[
MultiModalKwargs
]],
mm_hashes
:
List
[
Optional
[
str
]
]
,
mm_hashes
:
List
[
str
],
)
->
List
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
full_mm_inputs
=
[]
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
if
mm_input
is
None
:
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
assert
mm_input
is
not
None
...
...
vllm/v1/engine/processor.py
View file @
6d917d0e
...
...
@@ -56,7 +56,7 @@ class Processor:
request_id
:
str
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
...
vllm/v1/executor/abstract.py
View file @
6d917d0e
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
Optional
,
Tuple
from
typing
import
Tuple
from
vllm.config
import
VllmConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
...
...
@@ -28,7 +28,7 @@ class Executor(ABC):
raise
NotImplementedError
@
abstractmethod
def
profile
(
self
,
is_start
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
raise
NotImplementedError
@
abstractmethod
...
...
@@ -38,11 +38,3 @@ class Executor(ABC):
@
abstractmethod
def
check_health
(
self
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
collective_rpc
(
self
,
method
:
str
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
[]:
raise
NotImplementedError
vllm/v1/executor/multiproc_executor.py
View file @
6d917d0e
...
...
@@ -7,7 +7,7 @@ import time
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
multiprocessing.process
import
BaseProcess
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
zmq
...
...
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_distributed_init_method
,
get_open_port
,
get_open_zmq_ipc_path
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
make_zmq_socket
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
class
MultiprocExecutor
:
class
MultiprocExecutor
(
Executor
)
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
# Call self.shutdown at exit to clean up
...
...
@@ -103,7 +104,7 @@ class MultiprocExecutor:
method
:
str
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
[
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
"""
Execute an RPC call on workers.
...
...
@@ -125,7 +126,7 @@ class MultiprocExecutor:
responses
=
[
None
]
*
self
.
world_size
for
w
in
self
.
workers
:
dequeue_timeout
=
timeout
-
(
time
.
monotonic
()
-
start_time
()
dequeue_timeout
=
timeout
-
(
time
.
monotonic
()
-
start_time
)
if
timeout
is
not
None
else
None
status
,
result
=
w
.
worker_response_mq
.
dequeue
(
timeout
=
dequeue_timeout
)
...
...
@@ -153,7 +154,7 @@ class MultiprocExecutor:
args
=
(
scheduler_output
,
))[
0
]
return
model_output
def
profile
(
self
,
is_start
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,
))
return
...
...
@@ -185,7 +186,6 @@ class MultiprocExecutor:
p
.
kill
()
self
.
_cleanup_sockets
()
self
.
workers
=
None
def
_cleanup_sockets
(
self
):
for
w
in
self
.
workers
:
...
...
@@ -200,7 +200,8 @@ class MultiprocExecutor:
# again
atexit
.
unregister
(
self
.
shutdown
)
"""Properly shut down the executor and its workers"""
if
(
hasattr
(
self
,
'workers'
)
and
self
.
workers
is
not
None
):
if
getattr
(
self
,
'shutting_down'
,
False
):
self
.
shutting_down
=
True
for
w
in
self
.
workers
:
#TODO: not sure if needed
w
.
worker_response_mq
=
None
self
.
_ensure_worker_termination
()
...
...
vllm/v1/executor/uniproc_executor.py
View file @
6d917d0e
...
...
@@ -4,13 +4,14 @@ from typing import Optional, Tuple
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.gpu_worker
import
Worker
logger
=
init_logger
(
__name__
)
class
UniprocExecutor
:
class
UniprocExecutor
(
Executor
)
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
self
.
vllm_config
=
vllm_config
...
...
@@ -25,7 +26,7 @@ class UniprocExecutor:
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
worker
=
self
.
_create_worker
()
self
.
worker
:
Worker
=
self
.
_create_worker
()
self
.
worker
.
initialize
()
self
.
worker
.
load_model
()
...
...
@@ -75,7 +76,7 @@ class UniprocExecutor:
self
.
worker
.
profile
(
is_start
)
def
shutdown
(
self
):
self
.
worker
=
None
pass
def
check_health
(
self
)
->
None
:
# UniprocExecutor will always be healthy as long as
...
...
vllm/v1/request.py
View file @
6d917d0e
...
...
@@ -52,10 +52,9 @@ class Request:
else
:
self
.
mm_positions
=
[]
# Output of the mm input mapper (e.g., image tensors).
self
.
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
if
self
.
inputs
.
multi_modal_inputs
:
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
else
:
self
.
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
...
...
vllm/v1/utils.py
View file @
6d917d0e
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
contextlib
import
contextmanager
from
typing
import
Any
,
Generic
,
Iterator
,
List
,
TypeVar
,
overload
from
typing
import
(
Any
,
Generic
,
Iterator
,
List
,
Optional
,
TypeVar
,
Union
,
overload
)
import
zmq
...
...
@@ -11,7 +13,7 @@ logger = init_logger(__name__)
T
=
TypeVar
(
"T"
)
class
ConstantList
(
Generic
[
T
]):
class
ConstantList
(
Generic
[
T
]
,
Sequence
):
def
__init__
(
self
,
x
:
List
[
T
])
->
None
:
self
.
_x
=
x
...
...
@@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
def
clear
(
self
):
raise
Exception
(
"Cannot clear a constant list"
)
def
index
(
self
,
item
):
return
self
.
_x
.
index
(
item
)
def
index
(
self
,
item
:
T
,
start
:
int
=
0
,
stop
:
Optional
[
int
]
=
None
)
->
int
:
return
self
.
_x
.
index
(
item
,
start
,
stop
if
stop
is
not
None
else
len
(
self
.
_x
))
@
overload
def
__getitem__
(
self
,
item
)
->
T
:
def
__getitem__
(
self
,
item
:
int
)
->
T
:
...
@
overload
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
List
[
T
]:
...
def
__getitem__
(
self
,
item
)
:
def
__getitem__
(
self
,
item
:
Union
[
int
,
slice
])
->
Union
[
T
,
List
[
T
]]
:
return
self
.
_x
[
item
]
@
overload
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
:
int
,
value
:
T
):
...
@
overload
def
__setitem__
(
self
,
s
:
slice
,
value
,
/
):
def
__setitem__
(
self
,
s
:
slice
,
value
:
T
,
/
):
...
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
:
Union
[
int
,
slice
],
value
:
Union
[
T
,
List
[
T
]]
):
raise
Exception
(
"Cannot set item in a constant list"
)
def
__delitem__
(
self
,
item
):
...
...
@@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
@
contextmanager
def
make_zmq_socket
(
path
:
str
,
type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
def
make_zmq_socket
(
path
:
str
,
type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
# type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
ctx
=
zmq
.
Context
()
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
try
:
socket
=
ctx
.
socket
(
type
)
...
...
@@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
ctx
.
destroy
(
linger
=
0
)
class
LRUDictCache
:
K
=
TypeVar
(
'K'
)
V
=
TypeVar
(
'V'
)
class
LRUDictCache
(
Generic
[
K
,
V
]):
def
__init__
(
self
,
size
:
int
):
self
.
cache
=
OrderedDict
()
self
.
cache
:
OrderedDict
[
K
,
V
]
=
OrderedDict
()
self
.
size
=
size
def
get
(
self
,
key
,
default
=
None
):
def
get
(
self
,
key
:
K
,
default
=
None
)
->
V
:
if
key
not
in
self
.
cache
:
return
default
self
.
cache
.
move_to_end
(
key
)
return
self
.
cache
[
key
]
def
put
(
self
,
key
,
value
):
def
put
(
self
,
key
:
K
,
value
:
V
):
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
if
len
(
self
.
cache
)
>
self
.
size
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
6d917d0e
...
...
@@ -215,6 +215,7 @@ class InputBatch:
# Swap the states.
req_id
=
self
.
req_ids
[
last_req_index
]
assert
req_id
is
not
None
self
.
req_ids
[
empty_index
]
=
req_id
self
.
req_ids
[
last_req_index
]
=
None
self
.
req_id_to_index
[
req_id
]
=
empty_index
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6d917d0e
import
gc
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
,
cast
import
numpy
as
np
import
torch
...
...
@@ -193,9 +193,9 @@ class GPUModelRunner:
req_ids_to_add
:
List
[
str
]
=
[]
# Add new requests to the cached states.
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
req_data
.
req_id
sampling_params
=
req_data
.
sampling_params
for
new_
req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_
req_data
.
req_id
sampling_params
=
new_
req_data
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
...
...
@@ -204,25 +204,25 @@ class GPUModelRunner:
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
prompt_token_ids
=
req_data
.
prompt_token_ids
,
prompt
=
req_data
.
prompt
,
mm_inputs
=
req_data
.
mm_inputs
,
mm_positions
=
req_data
.
mm_positions
,
prompt_token_ids
=
new_
req_data
.
prompt_token_ids
,
prompt
=
new_
req_data
.
prompt
,
mm_inputs
=
new_
req_data
.
mm_inputs
,
mm_positions
=
new_
req_data
.
mm_positions
,
sampling_params
=
sampling_params
,
generator
=
generator
,
block_ids
=
req_data
.
block_ids
,
num_computed_tokens
=
req_data
.
num_computed_tokens
,
block_ids
=
new_
req_data
.
block_ids
,
num_computed_tokens
=
new_
req_data
.
num_computed_tokens
,
output_token_ids
=
[],
)
req_ids_to_add
.
append
(
req_id
)
# Update the cached states of the resumed requests.
for
req_data
in
scheduler_output
.
scheduled_resumed_reqs
:
req_id
=
req_data
.
req_id
for
res_
req_data
in
scheduler_output
.
scheduled_resumed_reqs
:
req_id
=
res_
req_data
.
req_id
req_state
=
self
.
requests
[
req_id
]
req_state
.
block_ids
=
req_data
.
block_ids
req_state
.
num_computed_tokens
=
req_data
.
num_computed_tokens
req_state
.
block_ids
=
res_
req_data
.
block_ids
req_state
.
num_computed_tokens
=
res_
req_data
.
num_computed_tokens
req_ids_to_add
.
append
(
req_id
)
# Add the new or resumed requests to the persistent batch.
...
...
@@ -259,6 +259,7 @@ class GPUModelRunner:
num_scheduled_tokens
=
[]
max_num_scheduled_tokens
=
0
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
.
append
(
num_tokens
)
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
...
...
@@ -373,7 +374,7 @@ class GPUModelRunner:
# Batch the multi-modal inputs.
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
req_input_ids
:
List
[
Tuple
[
int
,
int
]]
=
[]
req_input_ids
:
List
[
Tuple
[
str
,
int
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
...
...
@@ -406,6 +407,7 @@ class GPUModelRunner:
encoder_outputs
:
List
[
torch
.
Tensor
]
=
[]
num_reqs
=
self
.
input_batch
.
num_reqs
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
assert
req_id
is
not
None
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
...
...
@@ -514,6 +516,7 @@ class GPUModelRunner:
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
[:
num_reqs
]):
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
...
...
@@ -539,8 +542,15 @@ class GPUModelRunner:
logprobs
=
None
else
:
logprobs
=
sampler_output
.
logprobs
.
cpu
()
# num_reqs entries should be non-None
assert
all
(
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
req_ids
=
cast
(
List
[
str
],
self
.
input_batch
.
req_ids
[:
num_reqs
])
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
[:
num_reqs
]
,
req_ids
=
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
sampled_token_ids
,
logprob_token_ids_cpu
=
logprob_token_ids
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment