Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d917d0e
Unverified
Commit
6d917d0e
authored
Dec 14, 2024
by
Mark McLoughlin
Committed by
GitHub
Dec 14, 2024
Browse files
Enable mypy checking on V1 code (#11105)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
93abf23a
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
159 additions
and
120 deletions
+159
-120
tools/mypy.sh
tools/mypy.sh
+1
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+5
-5
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+9
-8
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+1
-0
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+14
-9
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+6
-5
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+10
-10
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+24
-19
vllm/v1/engine/detokenizer.py
vllm/v1/engine/detokenizer.py
+2
-2
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+2
-1
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+13
-7
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+1
-1
vllm/v1/executor/abstract.py
vllm/v1/executor/abstract.py
+2
-10
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+8
-7
vllm/v1/executor/uniproc_executor.py
vllm/v1/executor/uniproc_executor.py
+4
-3
vllm/v1/request.py
vllm/v1/request.py
+1
-2
vllm/v1/utils.py
vllm/v1/utils.py
+27
-15
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+1
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+26
-16
No files found.
tools/mypy.sh
View file @
6d917d0e
...
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
...
@@ -29,3 +29,4 @@ run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/worker
run_mypy vllm/v1
vllm/v1/attention/backends/flash_attn.py
View file @
6d917d0e
...
@@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -135,6 +135,8 @@ class FlashAttentionImpl(AttentionImpl):
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
"key/v_scale is not supported in FlashAttention."
)
assert
output
is
not
None
,
"Output tensor must be provided."
if
attn_metadata
is
None
:
if
attn_metadata
is
None
:
# Profiling run.
# Profiling run.
return
output
return
output
...
...
vllm/v1/core/kv_cache_manager.py
View file @
6d917d0e
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
...
@@ -263,12 +263,13 @@ class KVCacheManager:
...
@@ -263,12 +263,13 @@ class KVCacheManager:
"""
"""
# Default to [] in case a request is freed (aborted) before alloc.
# Default to [] in case a request is freed (aborted) before alloc.
blocks
=
self
.
req_to_blocks
.
pop
(
request
.
request_id
,
[])
blocks
=
self
.
req_to_blocks
.
pop
(
request
.
request_id
,
[])
ordered_blocks
:
Iterable
[
KVCacheBlock
]
=
blocks
if
self
.
enable_caching
:
if
self
.
enable_caching
:
# Free blocks in reverse order so that the tail blocks are
# Free blocks in reverse order so that the tail blocks are
# freed first.
# freed first.
blocks
=
reversed
(
blocks
)
ordered_
blocks
=
reversed
(
blocks
)
for
block
in
blocks
:
for
block
in
ordered_
blocks
:
block
.
decr_ref
()
block
.
decr_ref
()
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
self
.
free_block_queue
.
append
(
block
)
...
@@ -396,8 +397,7 @@ class KVCacheManager:
...
@@ -396,8 +397,7 @@ class KVCacheManager:
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Compute the hash of the current block.
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
)
tuple
(
block_tokens
))
# Update and added the full block to the cache.
# Update and added the full block to the cache.
blk
.
block_hash
=
block_hash
blk
.
block_hash
=
block_hash
...
...
vllm/v1/core/kv_cache_utils.py
View file @
6d917d0e
"""KV-Cache Utilities."""
"""KV-Cache Utilities."""
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
...
@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
...
@@ -13,7 +14,7 @@ class BlockHashType(NamedTuple):
collision happens when the hash value is the same.
collision happens when the hash value is the same.
"""
"""
hash_value
:
int
hash_value
:
int
token_ids
:
Tuple
[
int
]
token_ids
:
Tuple
[
int
,
...
]
@
dataclass
@
dataclass
...
@@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
...
@@ -79,8 +80,8 @@ class FreeKVCacheBlockQueue:
self
.
num_free_blocks
=
len
(
blocks
)
self
.
num_free_blocks
=
len
(
blocks
)
# Initialize the doubly linked list of free blocks.
# Initialize the doubly linked list of free blocks.
self
.
free_list_head
=
blocks
[
0
]
self
.
free_list_head
:
Optional
[
KVCacheBlock
]
=
blocks
[
0
]
self
.
free_list_tail
=
blocks
[
-
1
]
self
.
free_list_tail
:
Optional
[
KVCacheBlock
]
=
blocks
[
-
1
]
for
i
in
range
(
self
.
num_free_blocks
):
for
i
in
range
(
self
.
num_free_blocks
):
if
i
>
0
:
if
i
>
0
:
blocks
[
i
].
prev_free_block
=
blocks
[
i
-
1
]
blocks
[
i
].
prev_free_block
=
blocks
[
i
-
1
]
...
@@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
...
@@ -159,7 +160,7 @@ class FreeKVCacheBlockQueue:
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Tupl
e
[
int
])
->
BlockHashType
:
curr_block_token_ids
:
Sequenc
e
[
int
])
->
BlockHashType
:
"""Computes a hash value corresponding to the contents of a block and
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
prefix caching. We use LRU cache for this function to avoid recomputing
...
@@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
...
@@ -171,7 +172,7 @@ def hash_block_tokens(parent_block_hash: Optional[int],
Args:
Args:
parent_block_hash: The hash of the parent block. None
parent_block_hash: The hash of the parent block. None
if this is the first block.
if this is the first block.
curr_block_token_ids: A
tuple
of token ids in the current
curr_block_token_ids: A
list
of token ids in the current
block. The current block is assumed to be full.
block. The current block is assumed to be full.
Returns:
Returns:
...
@@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
...
@@ -179,11 +180,11 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The entire tuple is used as the hash key of the block.
The entire tuple is used as the hash key of the block.
"""
"""
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
curr_block_token_ids
)
tuple
(
curr_block_token_ids
)
)
def
hash_request_tokens
(
block_size
:
int
,
def
hash_request_tokens
(
block_size
:
int
,
token_ids
:
List
[
int
])
->
List
[
BlockHashType
]:
token_ids
:
Sequence
[
int
])
->
List
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
token IDs. The hash value is used for prefix caching.
...
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
...
@@ -198,7 +199,7 @@ def hash_request_tokens(block_size: int,
parent_block_hash_value
=
None
parent_block_hash_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
end
=
start
+
block_size
end
=
start
+
block_size
block_token_ids
=
tuple
(
token_ids
[
start
:
end
]
)
block_token_ids
=
token_ids
[
start
:
end
]
# Do not hash the block if it is not full.
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
if
len
(
block_token_ids
)
<
block_size
:
break
break
...
...
vllm/v1/core/scheduler.py
View file @
6d917d0e
...
@@ -152,6 +152,7 @@ class Scheduler:
...
@@ -152,6 +152,7 @@ class Scheduler:
break
break
if
not
can_schedule
:
if
not
can_schedule
:
break
break
assert
new_blocks
is
not
None
# Schedule the request.
# Schedule the request.
scheduled_running_reqs
.
append
(
request
)
scheduled_running_reqs
.
append
(
request
)
...
...
vllm/v1/engine/__init__.py
View file @
6d917d0e
...
@@ -36,7 +36,7 @@ class EngineCoreRequest:
...
@@ -36,7 +36,7 @@ class EngineCoreRequest:
prompt
:
Optional
[
str
]
prompt
:
Optional
[
str
]
prompt_token_ids
:
List
[
int
]
prompt_token_ids
:
List
[
int
]
mm_inputs
:
Optional
[
List
[
Optional
[
MultiModalKwargs
]]]
mm_inputs
:
Optional
[
List
[
Optional
[
MultiModalKwargs
]]]
mm_hashes
:
Optional
[
List
[
Optional
[
str
]]
]
mm_hashes
:
Optional
[
List
[
str
]]
mm_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
mm_placeholders
:
Optional
[
MultiModalPlaceholderDict
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
eos_token_id
:
Optional
[
int
]
eos_token_id
:
Optional
[
int
]
...
@@ -44,10 +44,11 @@ class EngineCoreRequest:
...
@@ -44,10 +44,11 @@ class EngineCoreRequest:
lora_request
:
Optional
[
LoRARequest
]
lora_request
:
Optional
[
LoRARequest
]
class
EngineCoreOutput
(
msgspec
.
Struct
,
class
EngineCoreOutput
(
array_like
=
True
,
msgspec
.
Struct
,
omit_defaults
=
True
,
array_like
=
True
,
# type: ignore[call-arg]
gc
=
False
):
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
request_id
:
str
request_id
:
str
new_token_ids
:
List
[
int
]
new_token_ids
:
List
[
int
]
...
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
...
@@ -56,10 +57,11 @@ class EngineCoreOutput(msgspec.Struct,
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
class
EngineCoreOutputs
(
msgspec
.
Struct
,
class
EngineCoreOutputs
(
array_like
=
True
,
msgspec
.
Struct
,
omit_defaults
=
True
,
array_like
=
True
,
# type: ignore[call-arg]
gc
=
False
):
omit_defaults
=
True
,
# type: ignore[call-arg]
gc
=
False
):
# type: ignore[call-arg]
#NOTE(Nick): We could consider ways to make this more compact,
#NOTE(Nick): We could consider ways to make this more compact,
# e.g. columnwise layout and using an int enum for finish/stop reason
# e.g. columnwise layout and using an int enum for finish/stop reason
...
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
...
@@ -81,3 +83,6 @@ class EngineCoreRequestType(enum.Enum):
ADD
=
b
'
\x00
'
ADD
=
b
'
\x00
'
ABORT
=
b
'
\x01
'
ABORT
=
b
'
\x01
'
PROFILE
=
b
'
\x02
'
PROFILE
=
b
'
\x02
'
EngineCoreRequestUnion
=
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]]
vllm/v1/engine/async_llm.py
View file @
6d917d0e
...
@@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
...
@@ -81,7 +81,7 @@ class AsyncLLM(EngineClient):
asyncio_mode
=
True
,
asyncio_mode
=
True
,
)
)
self
.
output_handler
=
None
self
.
output_handler
:
Optional
[
asyncio
.
Task
]
=
None
def
__del__
(
self
):
def
__del__
(
self
):
self
.
shutdown
()
self
.
shutdown
()
...
@@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
...
@@ -126,7 +126,8 @@ class AsyncLLM(EngineClient):
handler
.
cancel
()
handler
.
cancel
()
@
classmethod
@
classmethod
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
):
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
)
->
Type
[
Executor
]:
executor_class
:
Type
[
Executor
]
distributed_executor_backend
=
(
distributed_executor_backend
=
(
vllm_config
.
parallel_config
.
distributed_executor_backend
)
vllm_config
.
parallel_config
.
distributed_executor_backend
)
if
distributed_executor_backend
==
"mp"
:
if
distributed_executor_backend
==
"mp"
:
...
@@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
...
@@ -361,10 +362,10 @@ class AsyncLLM(EngineClient):
logger
.
debug
(
"Called check_health."
)
logger
.
debug
(
"Called check_health."
)
async
def
start_profile
(
self
)
->
None
:
async
def
start_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile
(
True
)
await
self
.
engine_core
.
profile
_async
(
True
)
async
def
stop_profile
(
self
)
->
None
:
async
def
stop_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile
(
False
)
await
self
.
engine_core
.
profile
_async
(
False
)
@
property
@
property
def
is_running
(
self
)
->
bool
:
def
is_running
(
self
)
->
bool
:
...
@@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
...
@@ -380,7 +381,7 @@ class AsyncLLM(EngineClient):
@
property
@
property
def
dead_error
(
self
)
->
BaseException
:
def
dead_error
(
self
)
->
BaseException
:
return
Exception
return
Exception
()
# TODO: implement
# Retain V0 name for backwards compatibility.
# Retain V0 name for backwards compatibility.
...
...
vllm/v1/engine/core.py
View file @
6d917d0e
...
@@ -5,7 +5,7 @@ import threading
...
@@ -5,7 +5,7 @@ import threading
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
multiprocessing.process
import
BaseProcess
from
multiprocessing.process
import
BaseProcess
from
typing
import
List
,
Tuple
,
Type
,
Union
from
typing
import
List
,
Tuple
,
Type
import
zmq
import
zmq
import
zmq.asyncio
import
zmq.asyncio
...
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
...
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
,
EngineCoreRequestUnion
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperServer
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
...
@@ -97,8 +97,10 @@ class EngineCore:
...
@@ -97,8 +97,10 @@ class EngineCore:
# Note that the cache here is mirrored with the client side of the
# Note that the cache here is mirrored with the client side of the
# MM mapper, so anything that has a hash must have a HIT cache
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
# entry here as well.
request
.
mm_inputs
=
self
.
mm_input_mapper_server
.
process_inputs
(
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
,
request
.
mm_hashes
)
request
.
mm_inputs
,
request
.
mm_hashes
=
(
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
))
req
=
Request
.
from_engine_core_request
(
request
)
req
=
Request
.
from_engine_core_request
(
request
)
...
@@ -128,7 +130,7 @@ class EngineCore:
...
@@ -128,7 +130,7 @@ class EngineCore:
def
shutdown
(
self
):
def
shutdown
(
self
):
self
.
model_executor
.
shutdown
()
self
.
model_executor
.
shutdown
()
def
profile
(
self
,
is_start
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
model_executor
.
profile
(
is_start
)
self
.
model_executor
.
profile
(
is_start
)
...
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
...
@@ -161,8 +163,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# and to overlap some serialization/deserialization with the
# model forward pass.
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self
.
input_queue
=
queue
.
Queue
()
self
.
input_queue
:
queue
.
Queue
[
EngineCoreRequestUnion
]
=
queue
.
Queue
()
self
.
output_queue
=
queue
.
Queue
()
self
.
output_queue
:
queue
.
Queue
[
List
[
EngineCoreOutput
]]
=
queue
.
Queue
()
threading
.
Thread
(
target
=
self
.
process_input_socket
,
threading
.
Thread
(
target
=
self
.
process_input_socket
,
args
=
(
input_path
,
),
args
=
(
input_path
,
),
daemon
=
True
).
start
()
daemon
=
True
).
start
()
...
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
...
@@ -318,9 +320,7 @@ class EngineCoreProc(EngineCore):
self
.
_last_logging_time
=
now
self
.
_last_logging_time
=
now
def
_handle_client_request
(
def
_handle_client_request
(
self
,
request
:
EngineCoreRequestUnion
)
->
None
:
self
,
request
:
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]])
->
None
:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
if
isinstance
(
request
,
EngineCoreRequest
):
if
isinstance
(
request
,
EngineCoreRequest
):
...
...
vllm/v1/engine/core_client.py
View file @
6d917d0e
import
atexit
import
atexit
import
os
import
os
from
typing
import
List
,
Un
ion
from
typing
import
List
,
Opt
ion
al
import
msgspec
import
msgspec
import
zmq
import
zmq
...
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
...
@@ -10,8 +10,9 @@ from vllm.logger import init_logger
from
vllm.utils
import
get_open_zmq_ipc_path
,
kill_process_tree
from
vllm.utils
import
get_open_zmq_ipc_path
,
kill_process_tree
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreProfile
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
,
EngineCoreRequestUnion
)
from
vllm.v1.engine.core
import
EngineCore
,
EngineCoreProc
from
vllm.v1.engine.core
import
(
EngineCore
,
EngineCoreProc
,
EngineCoreProcHandle
)
from
vllm.v1.serial_utils
import
PickleEncoder
from
vllm.v1.serial_utils
import
PickleEncoder
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -59,7 +60,7 @@ class EngineCoreClient:
...
@@ -59,7 +60,7 @@ class EngineCoreClient:
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
def
add_request
(
self
,
request
:
EngineCoreRequest
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
async
def
profile
(
self
,
is_start
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
...
@@ -71,6 +72,9 @@ class EngineCoreClient:
...
@@ -71,6 +72,9 @@ class EngineCoreClient:
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
async
def
add_request_async
(
self
,
request
:
EngineCoreRequest
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
async
def
profile_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
raise
NotImplementedError
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
async
def
abort_requests_async
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
...
@@ -105,7 +109,7 @@ class InprocClient(EngineCoreClient):
def
__del__
(
self
):
def
__del__
(
self
):
self
.
shutdown
()
self
.
shutdown
()
def
profile
(
self
,
is_start
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
engine_core
.
profile
(
is_start
)
self
.
engine_core
.
profile
(
is_start
)
...
@@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
...
@@ -133,7 +137,10 @@ class MPClient(EngineCoreClient):
self
.
decoder
=
msgspec
.
msgpack
.
Decoder
(
EngineCoreOutputs
)
self
.
decoder
=
msgspec
.
msgpack
.
Decoder
(
EngineCoreOutputs
)
# ZMQ setup.
# ZMQ setup.
self
.
ctx
=
(
zmq
.
asyncio
.
Context
()
if
asyncio_mode
else
zmq
.
Context
())
if
asyncio_mode
:
self
.
ctx
=
zmq
.
asyncio
.
Context
()
else
:
self
.
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
# Path for IPC.
# Path for IPC.
ready_path
=
get_open_zmq_ipc_path
()
ready_path
=
get_open_zmq_ipc_path
()
...
@@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
...
@@ -149,11 +156,13 @@ class MPClient(EngineCoreClient):
self
.
input_socket
.
bind
(
input_path
)
self
.
input_socket
.
bind
(
input_path
)
# Start EngineCore in background process.
# Start EngineCore in background process.
self
.
proc_handle
:
Optional
[
EngineCoreProcHandle
]
self
.
proc_handle
=
EngineCoreProc
.
make_engine_core_process
(
self
.
proc_handle
=
EngineCoreProc
.
make_engine_core_process
(
*
args
,
*
args
,
input_path
=
input_path
,
input_path
=
output_path
=
output_path
,
input_path
,
# type: ignore[misc] # MyPy incorrectly flags duplicate keywords
ready_path
=
ready_path
,
output_path
=
output_path
,
# type: ignore[misc]
ready_path
=
ready_path
,
# type: ignore[misc]
**
kwargs
,
**
kwargs
,
)
)
atexit
.
register
(
self
.
shutdown
)
atexit
.
register
(
self
.
shutdown
)
...
@@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
...
@@ -204,10 +213,8 @@ class SyncMPClient(MPClient):
engine_core_outputs
=
self
.
decoder
.
decode
(
frame
.
buffer
).
outputs
engine_core_outputs
=
self
.
decoder
.
decode
(
frame
.
buffer
).
outputs
return
engine_core_outputs
return
engine_core_outputs
def
_send_input
(
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
self
,
request_type
:
EngineCoreRequestType
,
request
:
EngineCoreRequestUnion
)
->
None
:
request
:
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]])
->
None
:
# (RequestType, SerializedRequest)
# (RequestType, SerializedRequest)
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
...
@@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
...
@@ -219,7 +226,7 @@ class SyncMPClient(MPClient):
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
def
profile
(
self
,
is_start
=
True
)
->
None
:
def
profile
(
self
,
is_start
:
bool
=
True
)
->
None
:
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
EngineCoreProfile
(
is_start
))
EngineCoreProfile
(
is_start
))
...
@@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
...
@@ -237,10 +244,8 @@ class AsyncMPClient(MPClient):
return
engine_core_outputs
return
engine_core_outputs
async
def
_send_input
(
async
def
_send_input
(
self
,
request_type
:
EngineCoreRequestType
,
self
,
request_type
:
EngineCoreRequestType
,
request
:
EngineCoreRequestUnion
)
->
None
:
request
:
Union
[
EngineCoreRequest
,
EngineCoreProfile
,
List
[
str
]])
->
None
:
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
msg
=
(
request_type
.
value
,
self
.
encoder
.
encode
(
request
))
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
await
self
.
input_socket
.
send_multipart
(
msg
,
copy
=
False
)
...
@@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
...
@@ -252,6 +257,6 @@ class AsyncMPClient(MPClient):
if
len
(
request_ids
)
>
0
:
if
len
(
request_ids
)
>
0
:
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
await
self
.
_send_input
(
EngineCoreRequestType
.
ABORT
,
request_ids
)
async
def
profile
(
self
,
is_start
=
True
)
->
None
:
async
def
profile
_async
(
self
,
is_start
:
bool
=
True
)
->
None
:
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
await
self
.
_send_input
(
EngineCoreRequestType
.
PROFILE
,
EngineCoreProfile
(
is_start
))
EngineCoreProfile
(
is_start
))
vllm/v1/engine/detokenizer.py
View file @
6d917d0e
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -97,7 +97,7 @@ class IncrementalDetokenizer:
...
@@ -97,7 +97,7 @@ class IncrementalDetokenizer:
self
,
self
,
new_token_ids
:
List
[
int
],
new_token_ids
:
List
[
int
],
finish_reason
:
Optional
[
str
],
finish_reason
:
Optional
[
str
],
stop_reason
:
Optional
[
str
],
stop_reason
:
Optional
[
Union
[
int
,
str
,
None
]
],
)
->
Optional
[
RequestOutput
]:
)
->
Optional
[
RequestOutput
]:
"""
"""
Update RequestState for the request_id by:
Update RequestState for the request_id by:
...
...
vllm/v1/engine/llm_engine.py
View file @
6d917d0e
...
@@ -103,7 +103,8 @@ class LLMEngine:
...
@@ -103,7 +103,8 @@ class LLMEngine:
multiprocess_mode
=
enable_multiprocessing
)
multiprocess_mode
=
enable_multiprocessing
)
@
classmethod
@
classmethod
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
):
def
_get_executor_cls
(
cls
,
vllm_config
:
VllmConfig
)
->
Type
[
Executor
]:
executor_class
:
Type
[
Executor
]
distributed_executor_backend
=
(
distributed_executor_backend
=
(
vllm_config
.
parallel_config
.
distributed_executor_backend
)
vllm_config
.
parallel_config
.
distributed_executor_backend
)
if
distributed_executor_backend
==
"mp"
:
if
distributed_executor_backend
==
"mp"
:
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
6d917d0e
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
PIL
import
PIL
from
blake3
import
blake3
from
blake3
import
blake3
...
@@ -42,14 +42,14 @@ class MMInputMapperClient:
...
@@ -42,14 +42,14 @@ class MMInputMapperClient:
model_config
)
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_cache
=
LRUDictCache
(
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
]
(
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
# DEBUG: Set to None to disable
self
.
mm_debug_cache_hit_ratio_steps
=
None
self
.
mm_debug_cache_hit_ratio_steps
=
None
self
.
mm_cache_hits
=
0
self
.
mm_cache_hits
=
0
self
.
mm_cache_total
=
0
self
.
mm_cache_total
=
0
def
cache_hit_ratio
(
self
,
steps
)
->
float
:
def
cache_hit_ratio
(
self
,
steps
):
if
self
.
mm_cache_total
>
0
and
self
.
mm_cache_total
%
steps
==
0
:
if
self
.
mm_cache_total
>
0
and
self
.
mm_cache_total
%
steps
==
0
:
logger
.
debug
(
"MMInputMapper: cache_hit_ratio = %.2f "
,
logger
.
debug
(
"MMInputMapper: cache_hit_ratio = %.2f "
,
self
.
mm_cache_hits
/
self
.
mm_cache_total
)
self
.
mm_cache_hits
/
self
.
mm_cache_total
)
...
@@ -60,7 +60,7 @@ class MMInputMapperClient:
...
@@ -60,7 +60,7 @@ class MMInputMapperClient:
mm_hashes
:
Optional
[
List
[
str
]],
mm_hashes
:
Optional
[
List
[
str
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
)
->
List
[
MultiModalKwargs
]:
)
->
Tuple
[
List
[
MultiModalKwargs
]
,
Optional
[
List
[
str
]]]
:
if
precomputed_mm_inputs
is
None
:
if
precomputed_mm_inputs
is
None
:
image_inputs
=
mm_data
[
"image"
]
image_inputs
=
mm_data
[
"image"
]
if
not
isinstance
(
image_inputs
,
list
):
if
not
isinstance
(
image_inputs
,
list
):
...
@@ -72,6 +72,7 @@ class MMInputMapperClient:
...
@@ -72,6 +72,7 @@ class MMInputMapperClient:
# Check if hash is enabled
# Check if hash is enabled
use_hash
=
mm_hashes
is
not
None
use_hash
=
mm_hashes
is
not
None
if
use_hash
:
if
use_hash
:
assert
mm_hashes
is
not
None
assert
num_inputs
==
len
(
assert
num_inputs
==
len
(
mm_hashes
),
"num_inputs = {} len(mm_hashes) = {}"
.
format
(
mm_hashes
),
"num_inputs = {} len(mm_hashes) = {}"
.
format
(
num_inputs
,
len
(
mm_hashes
))
num_inputs
,
len
(
mm_hashes
))
...
@@ -79,7 +80,7 @@ class MMInputMapperClient:
...
@@ -79,7 +80,7 @@ class MMInputMapperClient:
# Process each image input separately, so that later we can schedule
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes
=
[]
if
use_hash
else
None
ret_hashes
:
Optional
[
List
[
str
]]
=
[]
if
use_hash
else
None
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
for
input_id
in
range
(
num_inputs
):
for
input_id
in
range
(
num_inputs
):
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
...
@@ -88,6 +89,7 @@ class MMInputMapperClient:
...
@@ -88,6 +89,7 @@ class MMInputMapperClient:
mm_hash
=
None
mm_hash
=
None
mm_input
=
None
mm_input
=
None
if
use_hash
:
if
use_hash
:
assert
mm_hashes
is
not
None
mm_hash
=
mm_hashes
[
input_id
]
mm_hash
=
mm_hashes
[
input_id
]
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
...
@@ -105,12 +107,15 @@ class MMInputMapperClient:
...
@@ -105,12 +107,15 @@ class MMInputMapperClient:
if
use_hash
:
if
use_hash
:
# Add to cache
# Add to cache
assert
mm_hash
is
not
None
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
else
:
else
:
self
.
mm_cache_hits
+=
1
self
.
mm_cache_hits
+=
1
mm_input
=
None
# Avoids sending mm_input to Server
mm_input
=
None
# Avoids sending mm_input to Server
if
use_hash
:
if
use_hash
:
assert
mm_hash
is
not
None
assert
ret_hashes
is
not
None
ret_hashes
.
append
(
mm_hash
)
ret_hashes
.
append
(
mm_hash
)
ret_inputs
.
append
(
mm_input
)
ret_inputs
.
append
(
mm_input
)
...
@@ -120,17 +125,18 @@ class MMInputMapperClient:
...
@@ -120,17 +125,18 @@ class MMInputMapperClient:
class
MMInputMapperServer
:
class
MMInputMapperServer
:
def
__init__
(
self
,
):
def
__init__
(
self
,
):
self
.
mm_cache
=
LRUDictCache
(
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
]
(
MM_CACHE_SIZE
)
def
process_inputs
(
def
process_inputs
(
self
,
self
,
mm_inputs
:
List
[
Optional
[
MultiModalKwargs
]],
mm_inputs
:
List
[
Optional
[
MultiModalKwargs
]],
mm_hashes
:
List
[
Optional
[
str
]
]
,
mm_hashes
:
List
[
str
],
)
->
List
[
MultiModalKwargs
]:
)
->
List
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
full_mm_inputs
=
[]
full_mm_inputs
=
[]
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
if
mm_input
is
None
:
if
mm_input
is
None
:
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
assert
mm_input
is
not
None
assert
mm_input
is
not
None
...
...
vllm/v1/engine/processor.py
View file @
6d917d0e
...
@@ -56,7 +56,7 @@ class Processor:
...
@@ -56,7 +56,7 @@ class Processor:
request_id
:
str
,
request_id
:
str
,
prompt
:
PromptType
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
...
vllm/v1/executor/abstract.py
View file @
6d917d0e
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
Optional
,
Tuple
from
typing
import
Tuple
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
@@ -28,7 +28,7 @@ class Executor(ABC):
...
@@ -28,7 +28,7 @@ class Executor(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
profile
(
self
,
is_start
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
@@ -38,11 +38,3 @@ class Executor(ABC):
...
@@ -38,11 +38,3 @@ class Executor(ABC):
@
abstractmethod
@
abstractmethod
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
def
collective_rpc
(
self
,
method
:
str
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
[]:
raise
NotImplementedError
vllm/v1/executor/multiproc_executor.py
View file @
6d917d0e
...
@@ -7,7 +7,7 @@ import time
...
@@ -7,7 +7,7 @@ import time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
multiprocessing.process
import
BaseProcess
from
multiprocessing.process
import
BaseProcess
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
zmq
import
zmq
...
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
...
@@ -21,6 +21,7 @@ from vllm.executor.multiproc_worker_utils import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
get_distributed_init_method
,
get_open_port
,
from
vllm.utils
import
(
get_distributed_init_method
,
get_open_port
,
get_open_zmq_ipc_path
)
get_open_zmq_ipc_path
)
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
make_zmq_socket
from
vllm.v1.utils
import
make_zmq_socket
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
...
@@ -31,7 +32,7 @@ POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
class
MultiprocExecutor
:
class
MultiprocExecutor
(
Executor
)
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
# Call self.shutdown at exit to clean up
# Call self.shutdown at exit to clean up
...
@@ -103,7 +104,7 @@ class MultiprocExecutor:
...
@@ -103,7 +104,7 @@ class MultiprocExecutor:
method
:
str
,
method
:
str
,
timeout
:
Optional
[
float
]
=
None
,
timeout
:
Optional
[
float
]
=
None
,
args
:
Tuple
=
(),
args
:
Tuple
=
(),
kwargs
:
Optional
[
Dict
]
=
None
)
->
[
]:
kwargs
:
Optional
[
Dict
]
=
None
)
->
List
[
Any
]:
"""
"""
Execute an RPC call on workers.
Execute an RPC call on workers.
...
@@ -125,7 +126,7 @@ class MultiprocExecutor:
...
@@ -125,7 +126,7 @@ class MultiprocExecutor:
responses
=
[
None
]
*
self
.
world_size
responses
=
[
None
]
*
self
.
world_size
for
w
in
self
.
workers
:
for
w
in
self
.
workers
:
dequeue_timeout
=
timeout
-
(
time
.
monotonic
()
-
start_time
()
dequeue_timeout
=
timeout
-
(
time
.
monotonic
()
-
start_time
)
if
timeout
is
not
None
else
None
)
if
timeout
is
not
None
else
None
status
,
result
=
w
.
worker_response_mq
.
dequeue
(
status
,
result
=
w
.
worker_response_mq
.
dequeue
(
timeout
=
dequeue_timeout
)
timeout
=
dequeue_timeout
)
...
@@ -153,7 +154,7 @@ class MultiprocExecutor:
...
@@ -153,7 +154,7 @@ class MultiprocExecutor:
args
=
(
scheduler_output
,
))[
0
]
args
=
(
scheduler_output
,
))[
0
]
return
model_output
return
model_output
def
profile
(
self
,
is_start
=
True
):
def
profile
(
self
,
is_start
:
bool
=
True
):
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,
))
self
.
collective_rpc
(
"profile"
,
args
=
(
is_start
,
))
return
return
...
@@ -185,7 +186,6 @@ class MultiprocExecutor:
...
@@ -185,7 +186,6 @@ class MultiprocExecutor:
p
.
kill
()
p
.
kill
()
self
.
_cleanup_sockets
()
self
.
_cleanup_sockets
()
self
.
workers
=
None
def
_cleanup_sockets
(
self
):
def
_cleanup_sockets
(
self
):
for
w
in
self
.
workers
:
for
w
in
self
.
workers
:
...
@@ -200,7 +200,8 @@ class MultiprocExecutor:
...
@@ -200,7 +200,8 @@ class MultiprocExecutor:
# again
# again
atexit
.
unregister
(
self
.
shutdown
)
atexit
.
unregister
(
self
.
shutdown
)
"""Properly shut down the executor and its workers"""
"""Properly shut down the executor and its workers"""
if
(
hasattr
(
self
,
'workers'
)
and
self
.
workers
is
not
None
):
if
getattr
(
self
,
'shutting_down'
,
False
):
self
.
shutting_down
=
True
for
w
in
self
.
workers
:
#TODO: not sure if needed
for
w
in
self
.
workers
:
#TODO: not sure if needed
w
.
worker_response_mq
=
None
w
.
worker_response_mq
=
None
self
.
_ensure_worker_termination
()
self
.
_ensure_worker_termination
()
...
...
vllm/v1/executor/uniproc_executor.py
View file @
6d917d0e
...
@@ -4,13 +4,14 @@ from typing import Optional, Tuple
...
@@ -4,13 +4,14 @@ from typing import Optional, Tuple
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.gpu_worker
import
Worker
from
vllm.v1.worker.gpu_worker
import
Worker
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
UniprocExecutor
:
class
UniprocExecutor
(
Executor
)
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
...
@@ -25,7 +26,7 @@ class UniprocExecutor:
...
@@ -25,7 +26,7 @@ class UniprocExecutor:
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
worker
=
self
.
_create_worker
()
self
.
worker
:
Worker
=
self
.
_create_worker
()
self
.
worker
.
initialize
()
self
.
worker
.
initialize
()
self
.
worker
.
load_model
()
self
.
worker
.
load_model
()
...
@@ -75,7 +76,7 @@ class UniprocExecutor:
...
@@ -75,7 +76,7 @@ class UniprocExecutor:
self
.
worker
.
profile
(
is_start
)
self
.
worker
.
profile
(
is_start
)
def
shutdown
(
self
):
def
shutdown
(
self
):
self
.
worker
=
None
pass
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
# UniprocExecutor will always be healthy as long as
# UniprocExecutor will always be healthy as long as
...
...
vllm/v1/request.py
View file @
6d917d0e
...
@@ -52,10 +52,9 @@ class Request:
...
@@ -52,10 +52,9 @@ class Request:
else
:
else
:
self
.
mm_positions
=
[]
self
.
mm_positions
=
[]
# Output of the mm input mapper (e.g., image tensors).
# Output of the mm input mapper (e.g., image tensors).
self
.
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
if
self
.
inputs
.
multi_modal_inputs
:
if
self
.
inputs
.
multi_modal_inputs
:
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
else
:
self
.
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
@
classmethod
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
...
...
vllm/v1/utils.py
View file @
6d917d0e
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Generic
,
Iterator
,
List
,
TypeVar
,
overload
from
typing
import
(
Any
,
Generic
,
Iterator
,
List
,
Optional
,
TypeVar
,
Union
,
overload
)
import
zmq
import
zmq
...
@@ -11,7 +13,7 @@ logger = init_logger(__name__)
...
@@ -11,7 +13,7 @@ logger = init_logger(__name__)
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
class
ConstantList
(
Generic
[
T
]):
class
ConstantList
(
Generic
[
T
]
,
Sequence
):
def
__init__
(
self
,
x
:
List
[
T
])
->
None
:
def
__init__
(
self
,
x
:
List
[
T
])
->
None
:
self
.
_x
=
x
self
.
_x
=
x
...
@@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
...
@@ -34,29 +36,33 @@ class ConstantList(Generic[T]):
def
clear
(
self
):
def
clear
(
self
):
raise
Exception
(
"Cannot clear a constant list"
)
raise
Exception
(
"Cannot clear a constant list"
)
def
index
(
self
,
item
):
def
index
(
self
,
return
self
.
_x
.
index
(
item
)
item
:
T
,
start
:
int
=
0
,
stop
:
Optional
[
int
]
=
None
)
->
int
:
return
self
.
_x
.
index
(
item
,
start
,
stop
if
stop
is
not
None
else
len
(
self
.
_x
))
@
overload
@
overload
def
__getitem__
(
self
,
item
)
->
T
:
def
__getitem__
(
self
,
item
:
int
)
->
T
:
...
...
@
overload
@
overload
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
List
[
T
]:
def
__getitem__
(
self
,
s
:
slice
,
/
)
->
List
[
T
]:
...
...
def
__getitem__
(
self
,
item
)
:
def
__getitem__
(
self
,
item
:
Union
[
int
,
slice
])
->
Union
[
T
,
List
[
T
]]
:
return
self
.
_x
[
item
]
return
self
.
_x
[
item
]
@
overload
@
overload
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
:
int
,
value
:
T
):
...
...
@
overload
@
overload
def
__setitem__
(
self
,
s
:
slice
,
value
,
/
):
def
__setitem__
(
self
,
s
:
slice
,
value
:
T
,
/
):
...
...
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
:
Union
[
int
,
slice
],
value
:
Union
[
T
,
List
[
T
]]
):
raise
Exception
(
"Cannot set item in a constant list"
)
raise
Exception
(
"Cannot set item in a constant list"
)
def
__delitem__
(
self
,
item
):
def
__delitem__
(
self
,
item
):
...
@@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
...
@@ -73,10 +79,12 @@ class ConstantList(Generic[T]):
@
contextmanager
@
contextmanager
def
make_zmq_socket
(
path
:
str
,
type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
def
make_zmq_socket
(
path
:
str
,
type
:
Any
)
->
Iterator
[
zmq
.
Socket
]:
# type: ignore[name-defined]
"""Context manager for a ZMQ socket"""
"""Context manager for a ZMQ socket"""
ctx
=
zmq
.
Context
()
ctx
=
zmq
.
Context
()
# type: ignore[attr-defined]
try
:
try
:
socket
=
ctx
.
socket
(
type
)
socket
=
ctx
.
socket
(
type
)
...
@@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
...
@@ -96,20 +104,24 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
ctx
.
destroy
(
linger
=
0
)
ctx
.
destroy
(
linger
=
0
)
class
LRUDictCache
:
K
=
TypeVar
(
'K'
)
V
=
TypeVar
(
'V'
)
class
LRUDictCache
(
Generic
[
K
,
V
]):
def
__init__
(
self
,
size
:
int
):
def
__init__
(
self
,
size
:
int
):
self
.
cache
=
OrderedDict
()
self
.
cache
:
OrderedDict
[
K
,
V
]
=
OrderedDict
()
self
.
size
=
size
self
.
size
=
size
def
get
(
self
,
key
,
default
=
None
):
def
get
(
self
,
key
:
K
,
default
=
None
)
->
V
:
if
key
not
in
self
.
cache
:
if
key
not
in
self
.
cache
:
return
default
return
default
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
return
self
.
cache
[
key
]
return
self
.
cache
[
key
]
def
put
(
self
,
key
,
value
):
def
put
(
self
,
key
:
K
,
value
:
V
):
self
.
cache
[
key
]
=
value
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
if
len
(
self
.
cache
)
>
self
.
size
:
if
len
(
self
.
cache
)
>
self
.
size
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
6d917d0e
...
@@ -215,6 +215,7 @@ class InputBatch:
...
@@ -215,6 +215,7 @@ class InputBatch:
# Swap the states.
# Swap the states.
req_id
=
self
.
req_ids
[
last_req_index
]
req_id
=
self
.
req_ids
[
last_req_index
]
assert
req_id
is
not
None
self
.
req_ids
[
empty_index
]
=
req_id
self
.
req_ids
[
empty_index
]
=
req_id
self
.
req_ids
[
last_req_index
]
=
None
self
.
req_ids
[
last_req_index
]
=
None
self
.
req_id_to_index
[
req_id
]
=
empty_index
self
.
req_id_to_index
[
req_id
]
=
empty_index
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6d917d0e
import
gc
import
gc
import
time
import
time
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Tuple
,
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -193,9 +193,9 @@ class GPUModelRunner:
...
@@ -193,9 +193,9 @@ class GPUModelRunner:
req_ids_to_add
:
List
[
str
]
=
[]
req_ids_to_add
:
List
[
str
]
=
[]
# Add new requests to the cached states.
# Add new requests to the cached states.
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
for
new_
req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
req_data
.
req_id
req_id
=
new_
req_data
.
req_id
sampling_params
=
req_data
.
sampling_params
sampling_params
=
new_
req_data
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
generator
.
manual_seed
(
sampling_params
.
seed
)
...
@@ -204,25 +204,25 @@ class GPUModelRunner:
...
@@ -204,25 +204,25 @@ class GPUModelRunner:
self
.
requests
[
req_id
]
=
CachedRequestState
(
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_token_ids
=
req_data
.
prompt_token_ids
,
prompt_token_ids
=
new_
req_data
.
prompt_token_ids
,
prompt
=
req_data
.
prompt
,
prompt
=
new_
req_data
.
prompt
,
mm_inputs
=
req_data
.
mm_inputs
,
mm_inputs
=
new_
req_data
.
mm_inputs
,
mm_positions
=
req_data
.
mm_positions
,
mm_positions
=
new_
req_data
.
mm_positions
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
generator
=
generator
,
generator
=
generator
,
block_ids
=
req_data
.
block_ids
,
block_ids
=
new_
req_data
.
block_ids
,
num_computed_tokens
=
req_data
.
num_computed_tokens
,
num_computed_tokens
=
new_
req_data
.
num_computed_tokens
,
output_token_ids
=
[],
output_token_ids
=
[],
)
)
req_ids_to_add
.
append
(
req_id
)
req_ids_to_add
.
append
(
req_id
)
# Update the cached states of the resumed requests.
# Update the cached states of the resumed requests.
for
req_data
in
scheduler_output
.
scheduled_resumed_reqs
:
for
res_
req_data
in
scheduler_output
.
scheduled_resumed_reqs
:
req_id
=
req_data
.
req_id
req_id
=
res_
req_data
.
req_id
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
req_state
.
block_ids
=
req_data
.
block_ids
req_state
.
block_ids
=
res_
req_data
.
block_ids
req_state
.
num_computed_tokens
=
req_data
.
num_computed_tokens
req_state
.
num_computed_tokens
=
res_
req_data
.
num_computed_tokens
req_ids_to_add
.
append
(
req_id
)
req_ids_to_add
.
append
(
req_id
)
# Add the new or resumed requests to the persistent batch.
# Add the new or resumed requests to the persistent batch.
...
@@ -259,6 +259,7 @@ class GPUModelRunner:
...
@@ -259,6 +259,7 @@ class GPUModelRunner:
num_scheduled_tokens
=
[]
num_scheduled_tokens
=
[]
max_num_scheduled_tokens
=
0
max_num_scheduled_tokens
=
0
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
assert
req_id
is
not
None
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens
.
append
(
num_tokens
)
num_scheduled_tokens
.
append
(
num_tokens
)
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
max_num_scheduled_tokens
=
max
(
max_num_scheduled_tokens
,
...
@@ -373,7 +374,7 @@ class GPUModelRunner:
...
@@ -373,7 +374,7 @@ class GPUModelRunner:
# Batch the multi-modal inputs.
# Batch the multi-modal inputs.
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
mm_inputs
:
List
[
MultiModalKwargs
]
=
[]
req_input_ids
:
List
[
Tuple
[
int
,
int
]]
=
[]
req_input_ids
:
List
[
Tuple
[
str
,
int
]]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
for
input_id
in
encoder_input_ids
:
for
input_id
in
encoder_input_ids
:
...
@@ -406,6 +407,7 @@ class GPUModelRunner:
...
@@ -406,6 +407,7 @@ class GPUModelRunner:
encoder_outputs
:
List
[
torch
.
Tensor
]
=
[]
encoder_outputs
:
List
[
torch
.
Tensor
]
=
[]
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]:
assert
req_id
is
not
None
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
req_id
]
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
...
@@ -514,6 +516,7 @@ class GPUModelRunner:
...
@@ -514,6 +516,7 @@ class GPUModelRunner:
# the requests one by one. Optimize.
# the requests one by one. Optimize.
num_reqs
=
self
.
input_batch
.
num_reqs
num_reqs
=
self
.
input_batch
.
num_reqs
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
[:
num_reqs
]):
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
[:
num_reqs
]):
assert
req_id
is
not
None
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
scheduler_output
.
num_scheduled_tokens
[
req_id
])
...
@@ -539,8 +542,15 @@ class GPUModelRunner:
...
@@ -539,8 +542,15 @@ class GPUModelRunner:
logprobs
=
None
logprobs
=
None
else
:
else
:
logprobs
=
sampler_output
.
logprobs
.
cpu
()
logprobs
=
sampler_output
.
logprobs
.
cpu
()
# num_reqs entries should be non-None
assert
all
(
req_id
is
not
None
for
req_id
in
self
.
input_batch
.
req_ids
[:
num_reqs
]),
"req_ids contains None"
req_ids
=
cast
(
List
[
str
],
self
.
input_batch
.
req_ids
[:
num_reqs
])
model_runner_output
=
ModelRunnerOutput
(
model_runner_output
=
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
[:
num_reqs
]
,
req_ids
=
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
logprob_token_ids_cpu
=
logprob_token_ids
,
logprob_token_ids_cpu
=
logprob_token_ids
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment