Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96ae75ad
Commit
96ae75ad
authored
Jan 04, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.6.6.post1' into v0.6.6.post1-dev
parents
f9f4a735
2339d59f
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1257 additions
and
306 deletions
+1257
-306
vllm/transformers_utils/s3_utils.py
vllm/transformers_utils/s3_utils.py
+151
-0
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+1
-1
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+1
-1
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+4
-1
vllm/transformers_utils/utils.py
vllm/transformers_utils/utils.py
+4
-0
vllm/utils.py
vllm/utils.py
+142
-42
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+52
-35
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+105
-10
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+2
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+70
-105
vllm/v1/engine/async_stream.py
vllm/v1/engine/async_stream.py
+0
-55
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+5
-5
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+10
-4
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+52
-22
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+11
-23
vllm/v1/executor/ray_executor.py
vllm/v1/executor/ray_executor.py
+342
-0
vllm/v1/executor/ray_utils.py
vllm/v1/executor/ray_utils.py
+271
-0
vllm/v1/request.py
vllm/v1/request.py
+23
-1
vllm/v1/sample/metadata.py
vllm/v1/sample/metadata.py
+11
-1
vllm/v1/sample/ops/__init__.py
vllm/v1/sample/ops/__init__.py
+0
-0
No files found.
vllm/transformers_utils/s3_utils.py
0 → 100644
View file @
96ae75ad
import
fnmatch
import
os
import
shutil
import
signal
import
tempfile
from
pathlib
import
Path
from
typing
import
Optional
from
vllm.utils
import
PlaceholderModule
try
:
import
boto3
except
ImportError
:
boto3
=
PlaceholderModule
(
"boto3"
)
# type: ignore[assignment]
def
_filter_allow
(
paths
:
list
[
str
],
patterns
:
list
[
str
])
->
list
[
str
]:
return
[
path
for
path
in
paths
if
any
(
fnmatch
.
fnmatch
(
path
,
pattern
)
for
pattern
in
patterns
)
]
def
_filter_ignore
(
paths
:
list
[
str
],
patterns
:
list
[
str
])
->
list
[
str
]:
return
[
path
for
path
in
paths
if
not
any
(
fnmatch
.
fnmatch
(
path
,
pattern
)
for
pattern
in
patterns
)
]
def
glob
(
s3
=
None
,
path
:
str
=
""
,
allow_pattern
:
Optional
[
list
[
str
]]
=
None
)
->
list
[
str
]:
"""
List full file names from S3 path and filter by allow pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
Returns:
list[str]: List of full S3 paths allowed by the pattern
"""
if
s3
is
None
:
s3
=
boto3
.
client
(
"s3"
)
bucket_name
,
_
,
paths
=
list_files
(
s3
,
path
=
path
,
allow_pattern
=
allow_pattern
)
return
[
f
"s3://
{
bucket_name
}
/
{
path
}
"
for
path
in
paths
]
def
list_files
(
s3
,
path
:
str
,
allow_pattern
:
Optional
[
list
[
str
]]
=
None
,
ignore_pattern
:
Optional
[
list
[
str
]]
=
None
)
->
tuple
[
str
,
str
,
list
[
str
]]:
"""
List files from S3 path and filter by pattern.
Args:
s3: S3 client to use.
path: The S3 path to list from.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
Returns:
tuple[str, str, list[str]]: A tuple where:
- The first element is the bucket name
- The second element is string represent the bucket
and the prefix as a dir like string
- The third element is a list of files allowed or
disallowed by pattern
"""
parts
=
path
.
removeprefix
(
's3://'
).
split
(
'/'
)
prefix
=
'/'
.
join
(
parts
[
1
:])
bucket_name
=
parts
[
0
]
objects
=
s3
.
list_objects_v2
(
Bucket
=
bucket_name
,
Prefix
=
prefix
)
paths
=
[
obj
[
'Key'
]
for
obj
in
objects
.
get
(
'Contents'
,
[])]
paths
=
_filter_ignore
(
paths
,
[
"*/"
])
if
allow_pattern
is
not
None
:
paths
=
_filter_allow
(
paths
,
allow_pattern
)
if
ignore_pattern
is
not
None
:
paths
=
_filter_ignore
(
paths
,
ignore_pattern
)
return
bucket_name
,
prefix
,
paths
class
S3Model
:
"""
A class representing a S3 model mirrored into a temporary directory.
Attributes:
s3: S3 client.
dir: The temporary created directory.
Methods:
pull_files(): Pull model from S3 to the temporary directory.
"""
def
__init__
(
self
)
->
None
:
self
.
s3
=
boto3
.
client
(
's3'
)
for
sig
in
(
signal
.
SIGINT
,
signal
.
SIGTERM
):
existing_handler
=
signal
.
getsignal
(
sig
)
signal
.
signal
(
sig
,
self
.
_close_by_signal
(
existing_handler
))
self
.
dir
=
tempfile
.
mkdtemp
()
def
__del__
(
self
):
self
.
_close
()
def
_close
(
self
)
->
None
:
if
os
.
path
.
exists
(
self
.
dir
):
shutil
.
rmtree
(
self
.
dir
)
def
_close_by_signal
(
self
,
existing_handler
=
None
):
def
new_handler
(
signum
,
frame
):
self
.
_close
()
if
existing_handler
:
existing_handler
(
signum
,
frame
)
return
new_handler
def
pull_files
(
self
,
s3_model_path
:
str
=
""
,
allow_pattern
:
Optional
[
list
[
str
]]
=
None
,
ignore_pattern
:
Optional
[
list
[
str
]]
=
None
)
->
None
:
"""
Pull files from S3 storage into the temporary directory.
Args:
s3_model_path: The S3 path of the model.
allow_pattern: A list of patterns of which files to pull.
ignore_pattern: A list of patterns of which files not to pull.
"""
bucket_name
,
base_dir
,
files
=
list_files
(
self
.
s3
,
s3_model_path
,
allow_pattern
,
ignore_pattern
)
if
len
(
files
)
==
0
:
return
for
file
in
files
:
destination_file
=
self
.
dir
+
file
.
removeprefix
(
base_dir
)
local_dir
=
Path
(
destination_file
).
parent
os
.
makedirs
(
local_dir
,
exist_ok
=
True
)
self
.
s3
.
download_file
(
bucket_name
,
file
,
destination_file
)
vllm/transformers_utils/tokenizer.py
View file @
96ae75ad
...
@@ -132,7 +132,7 @@ def get_tokenizer(
...
@@ -132,7 +132,7 @@ def get_tokenizer(
if
is_from_mistral_org
and
tokenizer_mode
!=
"mistral"
:
if
is_from_mistral_org
and
tokenizer_mode
!=
"mistral"
:
warnings
.
warn
(
warnings
.
warn
(
'It is strongly recommended to run mistral models with '
'It is strongly recommended to run mistral models with '
'`--tokenizer
_
mode "mistral"` to ensure correct '
'`--tokenizer
-
mode "mistral"` to ensure correct '
'encoding and decoding.'
,
'encoding and decoding.'
,
FutureWarning
,
FutureWarning
,
stacklevel
=
2
)
stacklevel
=
2
)
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
96ae75ad
...
@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
...
@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
max_input_length
=
max_input_length
self
.
max_input_length
=
max_input_length
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
max_loras
=
tokenizer_config
.
get
(
"max_loras"
,
0
)
max_loras
=
tokenizer_config
.
get
(
"max_loras"
,
0
)
self
.
lora_tokenizers
=
LRUCache
[
AnyTokenizer
](
self
.
lora_tokenizers
=
LRUCache
[
int
,
AnyTokenizer
](
capacity
=
max
(
max_loras
,
max_num_seqs
)
if
enable_lora
else
0
)
capacity
=
max
(
max_loras
,
max_num_seqs
)
if
enable_lora
else
0
)
@
classmethod
@
classmethod
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
96ae75ad
...
@@ -314,12 +314,15 @@ class MistralTokenizer:
...
@@ -314,12 +314,15 @@ class MistralTokenizer:
if
regular_tokens
:
if
regular_tokens
:
decoded_list
.
append
(
decoded_list
.
append
(
self
.
decode
(
regular_tokens
))
# type: ignore
self
.
tokenizer
.
decode
(
regular_tokens
))
# type: ignore
decoded
=
''
.
join
(
decoded_list
)
decoded
=
''
.
join
(
decoded_list
)
return
decoded
return
decoded
# WARN: Outlines logits processors can overwrite this method.
# See: guided_decoding/outlines_logits_processors.py::_adapt_tokenizer
# for more.
def
decode
(
self
,
def
decode
(
self
,
ids
:
Union
[
List
[
int
],
int
],
ids
:
Union
[
List
[
int
],
int
],
skip_special_tokens
:
bool
=
True
)
->
str
:
skip_special_tokens
:
bool
=
True
)
->
str
:
...
...
vllm/transformers_utils/utils.py
View file @
96ae75ad
...
@@ -3,6 +3,10 @@ from pathlib import Path
...
@@ -3,6 +3,10 @@ from pathlib import Path
from
typing
import
Union
from
typing
import
Union
def
is_s3
(
model_or_path
:
str
)
->
bool
:
return
model_or_path
.
lower
().
startswith
(
's3://'
)
def
check_gguf_file
(
model
:
Union
[
str
,
PathLike
])
->
bool
:
def
check_gguf_file
(
model
:
Union
[
str
,
PathLike
])
->
bool
:
"""Check if the file is a GGUF model."""
"""Check if the file is a GGUF model."""
model
=
Path
(
model
)
model
=
Path
(
model
)
...
...
vllm/utils.py
View file @
96ae75ad
...
@@ -6,10 +6,13 @@ import datetime
...
@@ -6,10 +6,13 @@ import datetime
import
enum
import
enum
import
gc
import
gc
import
getpass
import
getpass
import
importlib.metadata
import
importlib.util
import
importlib.util
import
inspect
import
inspect
import
ipaddress
import
ipaddress
import
os
import
os
import
re
import
resource
import
signal
import
signal
import
socket
import
socket
import
subprocess
import
subprocess
...
@@ -21,14 +24,13 @@ import uuid
...
@@ -21,14 +24,13 @@ import uuid
import
warnings
import
warnings
import
weakref
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
collections.abc
import
Iterable
,
Mapping
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generator
,
Generic
,
Hashable
,
List
,
Literal
,
Dict
,
Generator
,
Generic
,
Hashable
,
List
,
Literal
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
overload
)
from
uuid
import
uuid4
from
uuid
import
uuid4
import
numpy
as
np
import
numpy
as
np
...
@@ -52,7 +54,7 @@ logger = init_logger(__name__)
...
@@ -52,7 +54,7 @@ logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/usage/compatibility_matrix.
rst
# Reminder: Please update docs/source/usage/compatibility_matrix.
md
# If the feature combo become valid
# If the feature combo become valid
STR_NOT_IMPL_ENC_DEC_SWA
=
\
STR_NOT_IMPL_ENC_DEC_SWA
=
\
...
@@ -154,10 +156,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
...
@@ -154,10 +156,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
}
P
=
ParamSpec
(
'P'
)
P
=
ParamSpec
(
'P'
)
K
=
TypeVar
(
"K"
)
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
U
=
TypeVar
(
"U"
)
U
=
TypeVar
(
"U"
)
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
class
_Sentinel
:
class
_Sentinel
:
...
...
...
@@ -190,50 +194,48 @@ class Counter:
...
@@ -190,50 +194,48 @@ class Counter:
self
.
counter
=
0
self
.
counter
=
0
class
LRUCache
(
Generic
[
T
]):
class
LRUCache
(
Generic
[
_K
,
_V
]):
def
__init__
(
self
,
capacity
:
int
):
def
__init__
(
self
,
capacity
:
int
)
->
None
:
self
.
cache
:
OrderedDict
[
Hashable
,
T
]
=
OrderedDict
()
self
.
cache
=
OrderedDict
[
_K
,
_V
]
()
self
.
pinned_items
:
Set
[
Hashable
]
=
set
()
self
.
pinned_items
=
set
[
_K
]
()
self
.
capacity
=
capacity
self
.
capacity
=
capacity
def
__contains__
(
self
,
key
:
Hashable
)
->
bool
:
def
__contains__
(
self
,
key
:
_K
)
->
bool
:
return
key
in
self
.
cache
return
key
in
self
.
cache
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
return
len
(
self
.
cache
)
return
len
(
self
.
cache
)
def
__getitem__
(
self
,
key
:
Hashable
)
->
T
:
def
__getitem__
(
self
,
key
:
_K
)
->
_V
:
value
=
self
.
cache
[
key
]
# Raise KeyError if not exists
value
=
self
.
cache
[
key
]
# Raise KeyError if not exists
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
return
value
return
value
def
__setitem__
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
def
__setitem__
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
put
(
key
,
value
)
self
.
put
(
key
,
value
)
def
__delitem__
(
self
,
key
:
Hashable
)
->
None
:
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
self
.
pop
(
key
)
self
.
pop
(
key
)
def
touch
(
self
,
key
:
Hashable
)
->
None
:
def
touch
(
self
,
key
:
_K
)
->
None
:
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
def
get
(
self
,
def
get
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
key
:
Hashable
,
value
:
Optional
[
_V
]
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
value
:
Optional
[
T
]
if
key
in
self
.
cache
:
if
key
in
self
.
cache
:
value
=
self
.
cache
[
key
]
value
=
self
.
cache
[
key
]
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
else
:
else
:
value
=
default
_value
value
=
default
return
value
return
value
def
put
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
def
put
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
cache
[
key
]
=
value
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
self
.
cache
.
move_to_end
(
key
)
self
.
_remove_old_if_needed
()
self
.
_remove_old_if_needed
()
def
pin
(
self
,
key
:
Hashable
)
->
None
:
def
pin
(
self
,
key
:
_K
)
->
None
:
"""
"""
Pins a key in the cache preventing it from being
Pins a key in the cache preventing it from being
evicted in the LRU order.
evicted in the LRU order.
...
@@ -242,13 +244,13 @@ class LRUCache(Generic[T]):
...
@@ -242,13 +244,13 @@ class LRUCache(Generic[T]):
raise
ValueError
(
f
"Cannot pin key:
{
key
}
not in cache."
)
raise
ValueError
(
f
"Cannot pin key:
{
key
}
not in cache."
)
self
.
pinned_items
.
add
(
key
)
self
.
pinned_items
.
add
(
key
)
def
_unpin
(
self
,
key
:
Hashable
)
->
None
:
def
_unpin
(
self
,
key
:
_K
)
->
None
:
self
.
pinned_items
.
remove
(
key
)
self
.
pinned_items
.
remove
(
key
)
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
])
:
def
_on_remove
(
self
,
key
:
_K
,
value
:
Optional
[
_V
])
->
None
:
pass
pass
def
remove_oldest
(
self
,
remove_pinned
=
False
)
:
def
remove_oldest
(
self
,
*
,
remove_pinned
:
bool
=
False
)
->
None
:
if
not
self
.
cache
:
if
not
self
.
cache
:
return
return
...
@@ -262,17 +264,15 @@ class LRUCache(Generic[T]):
...
@@ -262,17 +264,15 @@ class LRUCache(Generic[T]):
"cannot remove oldest from the cache."
)
"cannot remove oldest from the cache."
)
else
:
else
:
lru_key
=
next
(
iter
(
self
.
cache
))
lru_key
=
next
(
iter
(
self
.
cache
))
self
.
pop
(
lru_key
)
self
.
pop
(
lru_key
)
# type: ignore
def
_remove_old_if_needed
(
self
)
->
None
:
def
_remove_old_if_needed
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
self
.
remove_oldest
()
self
.
remove_oldest
()
def
pop
(
self
,
def
pop
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
key
:
Hashable
,
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
run_on_remove
=
key
in
self
.
cache
run_on_remove
=
key
in
self
.
cache
value
:
Optional
[
T
]
=
self
.
cache
.
pop
(
key
,
default
_value
)
value
=
self
.
cache
.
pop
(
key
,
default
)
# remove from pinned items
# remove from pinned items
if
key
in
self
.
pinned_items
:
if
key
in
self
.
pinned_items
:
self
.
_unpin
(
key
)
self
.
_unpin
(
key
)
...
@@ -280,7 +280,7 @@ class LRUCache(Generic[T]):
...
@@ -280,7 +280,7 @@ class LRUCache(Generic[T]):
self
.
_on_remove
(
key
,
value
)
self
.
_on_remove
(
key
,
value
)
return
value
return
value
def
clear
(
self
):
def
clear
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
0
:
while
len
(
self
.
cache
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
cache
.
clear
()
self
.
cache
.
clear
()
...
@@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
...
@@ -775,7 +775,7 @@ def get_dtype_size(dtype: torch.dtype) -> int:
# `collections` helpers
# `collections` helpers
def
is_list_of
(
def
is_list_of
(
value
:
object
,
value
:
object
,
typ
:
Type
[
T
],
typ
:
Union
[
type
[
T
],
tuple
[
type
[
T
],
...]
],
*
,
*
,
check
:
Literal
[
"first"
,
"all"
]
=
"first"
,
check
:
Literal
[
"first"
,
"all"
]
=
"first"
,
)
->
TypeIs
[
List
[
T
]]:
)
->
TypeIs
[
List
[
T
]]:
...
@@ -843,10 +843,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
...
@@ -843,10 +843,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
def
full_groupby
(
values
:
Iterable
[
_V
],
*
,
key
:
Callable
[[
_V
],
_K
]):
def
full_groupby
(
values
:
Iterable
[
_V
],
*
,
key
:
Callable
[[
_V
],
_K
]):
"""
"""
Unlike :class:`itertools.groupby`, groups are not broken by
Unlike :class:`itertools.groupby`, groups are not broken by
...
@@ -1282,6 +1278,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
...
@@ -1282,6 +1278,7 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
def
supports_kw
(
def
supports_kw
(
callable
:
Callable
[...,
object
],
callable
:
Callable
[...,
object
],
kw_name
:
str
,
kw_name
:
str
,
*
,
requires_kw_only
:
bool
=
False
,
requires_kw_only
:
bool
=
False
,
allow_var_kwargs
:
bool
=
True
,
allow_var_kwargs
:
bool
=
True
,
)
->
bool
:
)
->
bool
:
...
@@ -1326,6 +1323,8 @@ def resolve_mm_processor_kwargs(
...
@@ -1326,6 +1323,8 @@ def resolve_mm_processor_kwargs(
init_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
init_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
inference_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
inference_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
callable
:
Callable
[...,
object
],
callable
:
Callable
[...,
object
],
*
,
requires_kw_only
:
bool
=
True
,
allow_var_kwargs
:
bool
=
False
,
allow_var_kwargs
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
...
@@ -1344,11 +1343,17 @@ def resolve_mm_processor_kwargs(
...
@@ -1344,11 +1343,17 @@ def resolve_mm_processor_kwargs(
runtime_mm_kwargs
=
get_allowed_kwarg_only_overrides
(
runtime_mm_kwargs
=
get_allowed_kwarg_only_overrides
(
callable
,
callable
,
overrides
=
inference_kwargs
,
overrides
=
inference_kwargs
,
allow_var_kwargs
=
allow_var_kwargs
)
requires_kw_only
=
requires_kw_only
,
allow_var_kwargs
=
allow_var_kwargs
,
)
# Filter init time multimodal processor kwargs provided
# Filter init time multimodal processor kwargs provided
init_mm_kwargs
=
get_allowed_kwarg_only_overrides
(
init_mm_kwargs
=
get_allowed_kwarg_only_overrides
(
callable
,
overrides
=
init_kwargs
,
allow_var_kwargs
=
allow_var_kwargs
)
callable
,
overrides
=
init_kwargs
,
requires_kw_only
=
requires_kw_only
,
allow_var_kwargs
=
allow_var_kwargs
,
)
# Merge the final processor kwargs, prioritizing inference
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
# time values over the initialization time values.
...
@@ -1359,6 +1364,8 @@ def resolve_mm_processor_kwargs(
...
@@ -1359,6 +1364,8 @@ def resolve_mm_processor_kwargs(
def
get_allowed_kwarg_only_overrides
(
def
get_allowed_kwarg_only_overrides
(
callable
:
Callable
[...,
object
],
callable
:
Callable
[...,
object
],
overrides
:
Optional
[
Mapping
[
str
,
object
]],
overrides
:
Optional
[
Mapping
[
str
,
object
]],
*
,
requires_kw_only
:
bool
=
True
,
allow_var_kwargs
:
bool
=
False
,
allow_var_kwargs
:
bool
=
False
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
"""
"""
...
@@ -1390,16 +1397,21 @@ def get_allowed_kwarg_only_overrides(
...
@@ -1390,16 +1397,21 @@ def get_allowed_kwarg_only_overrides(
for
kwarg_name
,
val
in
overrides
.
items
()
for
kwarg_name
,
val
in
overrides
.
items
()
if
supports_kw
(
callable
,
if
supports_kw
(
callable
,
kwarg_name
,
kwarg_name
,
requires_kw_only
=
True
,
requires_kw_only
=
requires_kw_only
,
allow_var_kwargs
=
allow_var_kwargs
)
allow_var_kwargs
=
allow_var_kwargs
)
}
}
# If anything is dropped, log a warning
# If anything is dropped, log a warning
dropped_keys
=
overrides
.
keys
()
-
filtered_overrides
.
keys
()
dropped_keys
=
overrides
.
keys
()
-
filtered_overrides
.
keys
()
if
dropped_keys
:
if
dropped_keys
:
logger
.
warning
(
if
requires_kw_only
:
"The following intended overrides are not keyword-only args "
logger
.
warning
(
"and and will be dropped: %s"
,
dropped_keys
)
"The following intended overrides are not keyword-only args "
"and and will be dropped: %s"
,
dropped_keys
)
else
:
logger
.
warning
(
"The following intended overrides are not keyword args "
"and and will be dropped: %s"
,
dropped_keys
)
return
filtered_overrides
return
filtered_overrides
...
@@ -1628,6 +1640,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
...
@@ -1628,6 +1640,67 @@ def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
return
module
return
module
@
lru_cache
(
maxsize
=
None
)
def
get_vllm_optional_dependencies
():
metadata
=
importlib
.
metadata
.
metadata
(
"vllm"
)
requirements
=
metadata
.
get_all
(
"Requires-Dist"
,
[])
extras
=
metadata
.
get_all
(
"Provides-Extra"
,
[])
return
{
extra
:
[
re
.
split
(
r
";|>=|<=|=="
,
req
)[
0
]
for
req
in
requirements
if
req
.
endswith
(
f
'extra == "
{
extra
}
"'
)
]
for
extra
in
extras
}
@
dataclass
(
frozen
=
True
)
class
PlaceholderModule
:
"""
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
of a module that does not exists.
"""
name
:
str
def
placeholder_attr
(
self
,
attr_path
:
str
):
return
_PlaceholderModuleAttr
(
self
,
attr_path
)
def
__getattr__
(
self
,
key
:
str
):
name
=
self
.
name
try
:
importlib
.
import_module
(
self
.
name
)
except
ImportError
as
exc
:
for
extra
,
names
in
get_vllm_optional_dependencies
().
items
():
if
name
in
names
:
msg
=
f
"Please install vllm[
{
extra
}
] for
{
extra
}
support"
raise
ImportError
(
msg
)
from
exc
raise
exc
raise
AssertionError
(
"PlaceholderModule should not be used "
"when the original module can be imported"
)
@
dataclass
(
frozen
=
True
)
class
_PlaceholderModuleAttr
:
module
:
PlaceholderModule
attr_path
:
str
def
placeholder_attr
(
self
,
attr_path
:
str
):
return
_PlaceholderModuleAttr
(
self
.
module
,
f
"
{
self
.
attr_path
}
.
{
attr_path
}
"
)
def
__getattr__
(
self
,
key
:
str
):
getattr
(
self
.
module
,
f
"
{
self
.
attr_path
}
.
{
key
}
"
)
raise
AssertionError
(
"PlaceholderModule should not be used "
"when the original module can be imported"
)
# create a library to hold the custom op
# create a library to hold the custom op
vllm_lib
=
Library
(
"vllm"
,
"FRAGMENT"
)
# noqa
vllm_lib
=
Library
(
"vllm"
,
"FRAGMENT"
)
# noqa
...
@@ -1655,8 +1728,18 @@ def direct_register_custom_op(
...
@@ -1655,8 +1728,18 @@ def direct_register_custom_op(
library object. If you want to bind the operator to a different library,
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
make sure the library object is alive when the operator is used.
"""
"""
if
is_in_doc_build
()
or
not
supports_custom_op
()
:
if
is_in_doc_build
():
return
return
if
not
supports_custom_op
():
assert
not
current_platform
.
is_cuda_alike
(),
(
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies."
)
return
import
torch.library
import
torch.library
if
hasattr
(
torch
.
library
,
"infer_schema"
):
if
hasattr
(
torch
.
library
,
"infer_schema"
):
schema_str
=
torch
.
library
.
infer_schema
(
op_func
,
schema_str
=
torch
.
library
.
infer_schema
(
op_func
,
...
@@ -1823,3 +1906,20 @@ def memory_profiling(
...
@@ -1823,3 +1906,20 @@ def memory_profiling(
result
.
non_torch_increase_in_bytes
=
current_cuda_memory_bytes
-
baseline_memory_in_bytes
-
weights_memory_in_bytes
-
diff
.
torch_memory_in_bytes
# noqa
result
.
non_torch_increase_in_bytes
=
current_cuda_memory_bytes
-
baseline_memory_in_bytes
-
weights_memory_in_bytes
-
diff
.
torch_memory_in_bytes
# noqa
result
.
profile_time
=
diff
.
timestamp
result
.
profile_time
=
diff
.
timestamp
result
.
non_kv_cache_memory_in_bytes
=
result
.
non_torch_increase_in_bytes
+
result
.
torch_peak_increase_in_bytes
+
result
.
weights_memory_in_bytes
# noqa
result
.
non_kv_cache_memory_in_bytes
=
result
.
non_torch_increase_in_bytes
+
result
.
torch_peak_increase_in_bytes
+
result
.
weights_memory_in_bytes
# noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501Curre
def
set_ulimit
(
target_soft_limit
=
65535
):
resource_type
=
resource
.
RLIMIT_NOFILE
current_soft
,
current_hard
=
resource
.
getrlimit
(
resource_type
)
if
current_soft
<
target_soft_limit
:
try
:
resource
.
setrlimit
(
resource_type
,
(
target_soft_limit
,
current_hard
))
except
ValueError
as
e
:
logger
.
warning
(
"Found ulimit of %s and failed to automatically increase"
"with error %s. This can cause fd limit errors like"
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n"
,
current_soft
,
e
)
vllm/v1/core/kv_cache_manager.py
View file @
96ae75ad
...
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
...
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
hash_block_tokens
,
KVCacheBlock
,
generate_block_hash_extra_keys
,
hash_block_tokens
,
hash_request_tokens
)
hash_request_tokens
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -83,10 +85,12 @@ class KVCacheManager:
...
@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks
=
[]
computed_blocks
=
[]
# TODO(rickyx): potentially we could cache this so we don't have to
# The block hashes for the request may already be computed
# recompute it every time.
# if the request was preempted and resumed.
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
if
not
request
.
kv_block_hashes
:
request
.
all_token_ids
)
request
.
set_kv_block_hashes
(
hash_request_tokens
(
self
.
block_size
,
request
))
block_hashes
=
request
.
kv_block_hashes
for
block_hash
in
block_hashes
:
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# block_hashes is a chain of block hashes. If a block hash is not
...
@@ -197,23 +201,15 @@ class KVCacheManager:
...
@@ -197,23 +201,15 @@ class KVCacheManager:
f
"num_tokens must be greater than 0, got
{
num_tokens
}
"
)
f
"num_tokens must be greater than 0, got
{
num_tokens
}
"
)
# Touch the computed blocks to make sure they won't be evicted.
# Touch the computed blocks to make sure they won't be evicted.
num_evictable_computed_blocks
=
0
if
self
.
enable_caching
:
if
self
.
enable_caching
:
self
.
_touch
(
computed_blocks
)
self
.
_touch
(
computed_blocks
)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks
=
len
(
[
blk
for
blk
in
computed_blocks
if
blk
.
ref_cnt
==
0
])
else
:
else
:
assert
not
computed_blocks
,
(
assert
not
computed_blocks
,
(
"Computed blocks should be empty when "
"Computed blocks should be empty when "
"prefix caching is disabled"
)
"prefix caching is disabled"
)
num_required_blocks
=
cdiv
(
num_tokens
,
self
.
block_size
)
num_required_blocks
=
cdiv
(
num_tokens
,
self
.
block_size
)
if
(
num_required_blocks
>
self
.
free_block_queue
.
num_free_blocks
-
if
(
num_required_blocks
>
self
.
free_block_queue
.
num_free_blocks
):
num_evictable_computed_blocks
):
# Cannot allocate new blocks.
# Cannot allocate new blocks.
return
None
return
None
...
@@ -221,8 +217,7 @@ class KVCacheManager:
...
@@ -221,8 +217,7 @@ class KVCacheManager:
# preallocated blocks.
# preallocated blocks.
num_new_blocks
=
min
(
num_new_blocks
=
min
(
num_required_blocks
+
self
.
num_preallocate_blocks
,
num_required_blocks
+
self
.
num_preallocate_blocks
,
self
.
free_block_queue
.
num_free_blocks
-
self
.
free_block_queue
.
num_free_blocks
,
num_evictable_computed_blocks
,
# Should not exceed the maximum number of blocks per request.
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# [..., max_num_blocks_per_req].
...
@@ -242,14 +237,16 @@ class KVCacheManager:
...
@@ -242,14 +237,16 @@ class KVCacheManager:
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
self
.
_cache_full_blocks
(
new_full_blocks
=
self
.
req_to_blocks
[
request
=
request
,
request
.
request_id
][
len
(
computed_blocks
):
num_full_blocks
]
blk_start_idx
=
len
(
computed_blocks
),
if
new_full_blocks
:
# The new full blocks are the full blocks that are not computed.
self
.
_cache_full_blocks
(
full_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
request
=
request
,
[
len
(
computed_blocks
):
num_full_blocks
],
blk_start_idx
=
len
(
computed_blocks
),
prev_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
,
# The new full blocks are the full blocks that are not computed.
)
full_blocks
=
new_full_blocks
,
prev_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
,
)
return
new_blocks
return
new_blocks
...
@@ -376,6 +373,8 @@ class KVCacheManager:
...
@@ -376,6 +373,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
prev_block: The previous block in the chain.
"""
"""
num_cached_block_hashes
=
len
(
request
.
kv_block_hashes
)
# Update the new blocks with the block hashes through the chain.
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value
=
None
prev_block_hash_value
=
None
if
prev_block
is
not
None
:
if
prev_block
is
not
None
:
...
@@ -387,17 +386,35 @@ class KVCacheManager:
...
@@ -387,17 +386,35 @@ class KVCacheManager:
for
i
,
blk
in
enumerate
(
full_blocks
):
for
i
,
blk
in
enumerate
(
full_blocks
):
blk_idx
=
blk_start_idx
+
i
blk_idx
=
blk_start_idx
+
i
block_tokens
=
request
.
all_token_ids
[
blk_idx
*
if
blk_idx
<
num_cached_block_hashes
:
self
.
block_size
:(
blk_idx
+
# The block hash may already be computed in
1
)
*
# "get_computed_blocks" if the tokens are not generated by
self
.
block_size
]
# this request (either the prompt tokens or the previously
assert
len
(
block_tokens
)
==
self
.
block_size
,
(
# generated tokens with preemption). In this case we simply
f
"Expected
{
self
.
block_size
}
tokens, got
{
len
(
block_tokens
)
}
"
# reuse the block hash.
f
"at
{
blk_idx
}
th block for request "
block_hash
=
request
.
kv_block_hashes
[
blk_idx
]
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
else
:
# Otherwise compute the block hash and cache it in the request
# Compute the hash of the current block.
# in case it will be preempted in the future.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
)
start_token_idx
=
blk_idx
*
self
.
block_size
end_token_idx
=
(
blk_idx
+
1
)
*
self
.
block_size
block_tokens
=
request
.
all_token_ids
[
start_token_idx
:
end_token_idx
]
assert
len
(
block_tokens
)
==
self
.
block_size
,
(
f
"Expected
{
self
.
block_size
}
tokens, got "
f
"
{
len
(
block_tokens
)
}
at
{
blk_idx
}
th block for request "
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys
,
_
=
generate_block_hash_extra_keys
(
request
,
start_token_idx
,
end_token_idx
,
-
1
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
,
extra_keys
)
request
.
append_kv_block_hashes
(
block_hash
)
# Update and added the full block to the cache.
# Update and added the full block to the cache.
blk
.
block_hash
=
block_hash
blk
.
block_hash
=
block_hash
...
...
vllm/v1/core/kv_cache_utils.py
View file @
96ae75ad
"""KV-Cache Utilities."""
"""KV-Cache Utilities."""
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
from
typing
import
Any
,
List
,
NamedTuple
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
BlockHashType
(
NamedTuple
):
class
BlockHashType
(
NamedTuple
):
"""Hash value of a block
and
the token IDs in the block.
"""Hash value of a block
(int),
the token IDs in the block
, and extra keys
.
The reason we keep a tuple of token IDs is to make sure
no hash
The reason we keep a tuple of token IDs
and extra keys
is to make sure
collision happens when the hash value is the same.
no hash
collision happens when the hash value is the same.
"""
"""
# Hash value of the block in an integer.
hash_value
:
int
hash_value
:
int
# Token IDs in the block.
token_ids
:
Tuple
[
int
,
...]
token_ids
:
Tuple
[
int
,
...]
# Extra keys for the block.
extra_keys
:
Optional
[
Any
]
=
None
@
dataclass
@
dataclass
...
@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
...
@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return
ret
return
ret
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
def
generate_block_hash_extra_keys
(
curr_block_token_ids
:
Sequence
[
int
])
->
BlockHashType
:
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
,
start_mm_idx
:
int
)
->
Tuple
[
Optional
[
Tuple
[
Any
,
...]],
int
]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hashes
if
not
mm_positions
:
return
None
,
start_mm_idx
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set disable_mm_preprocessor_cache=False."
)
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
][
"offset"
]
+
mm_positions
[
-
1
][
"length"
]
<
start_token_idx
:
return
None
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if
start_mm_idx
<
0
:
assert
-
start_mm_idx
<=
len
(
mm_positions
)
start_mm_idx
=
len
(
mm_positions
)
+
start_mm_idx
extra_keys
=
[]
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
][
"offset"
]
length
=
mm_positions
[
curr_mm_idx
][
"length"
]
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
curr_mm_idx
+=
1
continue
# The block contains the current mm input.
mm_start
=
max
(
0
,
start_token_idx
-
offset
)
extra_keys
.
append
((
mm_hashes
[
curr_mm_idx
],
mm_start
))
if
end_token_idx
>=
offset
+
length
:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx
+=
1
else
:
# Otherwise this block is done with mm inputs.
break
else
:
# This block has not reached the current mm input.
break
return
tuple
(
extra_keys
),
curr_mm_idx
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Sequence
[
int
],
extra_keys
:
Optional
[
Tuple
[
Any
,
...]]
=
None
)
->
BlockHashType
:
"""Computes a hash value corresponding to the contents of a block and
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
prefix caching. We use LRU cache for this function to avoid recomputing
...
@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
...
@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block.
if this is the first block.
curr_block_token_ids: A list of token ids in the current
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
Returns:
The hash value of the block and the token ids in the block.
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
The entire tuple is used as the hash key of the block.
"""
"""
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
tuple
(
curr_block_token_ids
))
tuple
(
curr_block_token_ids
)
,
extra_keys
)
def
hash_request_tokens
(
block_size
:
int
,
def
hash_request_tokens
(
block_size
:
int
,
token_ids
:
S
eque
nce
[
int
]
)
->
List
[
BlockHashType
]:
request
:
R
eque
st
)
->
List
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
token IDs. The hash value is used for prefix caching.
Args:
Args:
block_size: The size of each block.
block_size: The size of each block.
token_ids: A sequence of token ids in t
he request.
request: T
he request
object
.
Returns:
Returns:
The list of computed hash values.
The list of computed hash values.
"""
"""
token_ids
=
request
.
all_token_ids
mm_positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hashes
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match."
)
# TODO: Extend this to support other features such as LoRA.
need_extra_keys
=
bool
(
mm_positions
)
extra_keys
=
None
curr_mm_idx
=
0
ret
=
[]
ret
=
[]
parent_block_hash_value
=
None
parent_block_hash_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
...
@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
...
@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
if
len
(
block_token_ids
)
<
block_size
:
break
break
# Add extra keys if the block is a multi-modal block.
if
need_extra_keys
:
extra_keys
,
curr_mm_idx
=
generate_block_hash_extra_keys
(
request
,
start
,
end
,
curr_mm_idx
)
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_token_ids
)
block_token_ids
,
extra_keys
)
ret
.
append
(
block_hash
)
ret
.
append
(
block_hash
)
parent_block_hash_value
=
block_hash
.
hash_value
parent_block_hash_value
=
block_hash
.
hash_value
return
ret
return
ret
vllm/v1/core/scheduler.py
View file @
96ae75ad
...
@@ -516,6 +516,7 @@ class NewRequestData:
...
@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids
:
List
[
int
]
prompt_token_ids
:
List
[
int
]
prompt
:
Optional
[
str
]
prompt
:
Optional
[
str
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_hashes
:
List
[
str
]
mm_positions
:
List
[
"PlaceholderRange"
]
mm_positions
:
List
[
"PlaceholderRange"
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
block_ids
:
List
[
int
]
block_ids
:
List
[
int
]
...
@@ -533,6 +534,7 @@ class NewRequestData:
...
@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt
=
request
.
prompt
,
prompt
=
request
.
prompt
,
mm_inputs
=
request
.
mm_inputs
,
mm_inputs
=
request
.
mm_inputs
,
mm_hashes
=
request
.
mm_hashes
,
mm_positions
=
request
.
mm_positions
,
mm_positions
=
request
.
mm_positions
,
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
block_ids
=
block_ids
,
block_ids
=
block_ids
,
...
...
vllm/v1/engine/async_llm.py
View file @
96ae75ad
...
@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
...
@@ -9,14 +9,13 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.v1.engine.async_stream
import
AsyncStream
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.detokenizer
import
Detokenizer
from
vllm.v1.engine.detokenizer
import
Detokenizer
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.engine.processor
import
Processor
...
@@ -54,15 +53,17 @@ class AsyncLLM(EngineClient):
...
@@ -54,15 +53,17 @@ class AsyncLLM(EngineClient):
lora_config
=
vllm_config
.
lora_config
)
lora_config
=
vllm_config
.
lora_config
)
self
.
tokenizer
.
ping
()
self
.
tokenizer
.
ping
()
# Request streams (map of request_id -> AsyncStream).
# Request streams (map of request_id -> queue).
self
.
request_streams
:
Dict
[
str
,
AsyncStream
]
=
{}
self
.
rid_to_queue
:
Dict
[
str
,
asyncio
.
Queue
]
=
{}
# List of cancelled request ids to be aborted.
self
.
client_aborted_requests
:
List
[
str
]
=
[]
# Processor (converts Inputs --> EngineCoreRequests).
# Processor (converts Inputs --> EngineCoreRequests).
self
.
processor
=
Processor
(
vllm_config
.
model_config
,
self
.
processor
=
Processor
(
vllm_config
.
lora_config
,
self
.
tokenizer
,
model_config
=
vllm_config
.
model_config
,
input_registry
)
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self
.
detokenizer
=
Detokenizer
(
self
.
detokenizer
=
Detokenizer
(
...
@@ -94,7 +95,7 @@ class AsyncLLM(EngineClient):
...
@@ -94,7 +95,7 @@ class AsyncLLM(EngineClient):
start_engine_loop
:
bool
=
True
,
start_engine_loop
:
bool
=
True
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
)
->
"AsyncLLM
Engine
"
:
)
->
"AsyncLLM"
:
"""Create an AsyncLLM from the EngineArgs."""
"""Create an AsyncLLM from the EngineArgs."""
# Create the engine configs.
# Create the engine configs.
...
@@ -149,14 +150,13 @@ class AsyncLLM(EngineClient):
...
@@ -149,14 +150,13 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
A
sync
Generator
[
Union
[
RequestOutput
,
Pooling
RequestOutput
]
,
None
]
:
)
->
a
sync
io
.
Queue
[
RequestOutput
]:
"""Add new request to the AsyncLLM."""
"""Add new request to the AsyncLLM."""
if
self
.
detokenizer
.
is_request_active
(
request_id
):
# 1) Create a new output queue for the request.
raise
ValueError
(
f
"Request
{
request_id
}
already exists."
)
if
request_id
in
self
.
rid_to_queue
:
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
# 1) Create a new AsyncStream for the request.
self
.
rid_to_queue
[
request_id
]
=
asyncio
.
Queue
()
stream
=
self
.
_add_request_to_streams
(
request_id
)
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
# 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
detokenizer_req
,
engine_core_req
=
self
.
processor
.
process_inputs
(
detokenizer_req
,
engine_core_req
=
self
.
processor
.
process_inputs
(
...
@@ -169,8 +169,10 @@ class AsyncLLM(EngineClient):
...
@@ -169,8 +169,10 @@ class AsyncLLM(EngineClient):
# 4) Add the EngineCoreRequest to EngineCore (separate process).
# 4) Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
engine_core_req
)
await
self
.
engine_core
.
add_request_async
(
engine_core_req
)
# 5) Return the generator.
if
self
.
log_requests
:
return
stream
.
generator
()
logger
.
info
(
"Added request %s."
,
request_id
)
return
self
.
rid_to_queue
[
request_id
]
# TODO: we should support multiple prompts in one call, as you
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
# can do with LLM.generate. So that for multi-prompt completion
...
@@ -190,7 +192,7 @@ class AsyncLLM(EngineClient):
...
@@ -190,7 +192,7 @@ class AsyncLLM(EngineClient):
"""
"""
Main function called by the API server to kick off a request
Main function called by the API server to kick off a request
* 1) Making an AsyncStream corresponding to the Request.
* 1) Making an AsyncStream corresponding to the Request.
#
2) Processing the Input.
*
2) Processing the Input.
* 3) Adding the Request to the Detokenizer.
* 3) Adding the Request to the Detokenizer.
* 4) Adding the Request to the EngineCore (separate process).
* 4) Adding the Request to the EngineCore (separate process).
...
@@ -202,14 +204,15 @@ class AsyncLLM(EngineClient):
...
@@ -202,14 +204,15 @@ class AsyncLLM(EngineClient):
returning the RequestOutput back to the caller.
returning the RequestOutput back to the caller.
"""
"""
# We start the output_handler on the first call to generate() so that
try
:
# we can call __init__ before the event loop starts, which enables us
# We start the output_handler on the first call to generate() so
# to handle startup failure gracefully in the OpenAI server.
# we can call __init__ before the event loop, which enables us
if
self
.
output_handler
is
None
:
# to handle startup failure gracefully in the OpenAI server.
self
.
output_handler
=
asyncio
.
create_task
(
if
self
.
output_handler
is
None
:
self
.
_run_output_handler
())
self
.
output_handler
=
asyncio
.
create_task
(
self
.
_run_output_handler
())
async
for
output
in
await
self
.
add_request
(
q
=
await
self
.
add_request
(
request_id
,
request_id
,
prompt
,
prompt
,
sampling_params
,
sampling_params
,
...
@@ -217,79 +220,42 @@ class AsyncLLM(EngineClient):
...
@@ -217,79 +220,42 @@ class AsyncLLM(EngineClient):
trace_headers
=
trace_headers
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
priority
=
priority
,
):
)
yield
output
def
_finish_stream
(
self
,
request_id
:
str
):
stream
=
self
.
request_streams
.
pop
(
request_id
,
None
)
if
stream
is
not
None
:
stream
.
finish
()
def
_add_request_to_streams
(
self
,
request_id
:
str
,
)
->
AsyncStream
:
if
request_id
in
self
.
request_streams
:
raise
ValueError
(
f
"Request id
{
request_id
}
already running."
)
# Avoid streams having circular ref to parent AsyncLLM object.
aborted_reqs
=
self
.
client_aborted_requests
stream
=
AsyncStream
(
request_id
,
aborted_reqs
.
append
)
self
.
request_streams
[
request_id
]
=
stream
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request_id
)
return
stream
# The output_handler task pushes items into the queue.
# This task pulls from the queue and yields to caller.
async
def
_process_cancellations
(
self
)
->
None
:
while
True
:
"""
# Note: drain queue without await if possible (avoids
Process requests cancelled from user disconnecting.
# task switching under load which helps performance).
out
=
q
.
get_nowait
()
if
q
.
qsize
()
>
0
else
await
q
.
get
()
When a client disconnects, AsyncStream._cancel() is called.
We passed a callback to AsyncStream(), which appends to
# Note: both Detokenizer and EngineCore handle their
self.client_aborted_requests.
# own request cleanup based on finished.
if
out
.
finished
:
As a result, if any requests are canceled from the user side
del
self
.
rid_to_queue
[
request_id
]
the request_id will show up in self.client_aborted_requests.
yield
out
"""
break
# Avoid streams having circular ref to parent AsyncLLM object.
yield
out
if
not
self
.
client_aborted_requests
:
return
# If the request is disconnected by the client, the
reqs_to_abort
=
self
.
client_aborted_requests
.
copy
()
# generate() task will be canceled. So, we abort the
self
.
client_aborted_requests
.
clear
()
# request if we end up here.
except
asyncio
.
CancelledError
:
# Remove from Detokenizer.
await
self
.
abort
(
request_id
)
self
.
detokenizer
.
abort_requests
(
reqs_to_abort
)
raise
# Remove from RequestStreams.
for
request_id
in
reqs_to_abort
:
if
self
.
log_requests
:
logger
.
info
(
"User-cancelled request %s."
,
request_id
)
self
.
_finish_stream
(
request_id
)
# Remove from EngineCore.
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
def
_process_request_outputs
(
self
,
request_outputs
:
List
[
RequestOutput
]):
def
_process_request_outputs
(
self
,
request_outputs
:
List
[
RequestOutput
]):
"""Process outputs by putting them into per-request
AsyncStream
s."""
"""Process outputs by putting them into per-request
queue
s."""
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
request_id
=
request_output
.
request_id
request_id
=
request_output
.
request_id
assert
request_id
in
self
.
request_streams
# Each request in the API server pulls from the per-request stream.
stream
=
self
.
request_streams
.
get
(
request_id
)
if
stream
is
not
None
:
stream
.
put
(
request_output
)
# If finished,
remove from
the tracker.
# Note: it is possible a request was aborted and
remove
d
from
if
request_output
.
finished
:
# the state due to client cancellations, so if we encounter a
if
self
.
log_requests
:
# request id not in the state, we skip.
logger
.
info
(
"Finished request %s."
,
request_id
)
if
request_id
in
self
.
rid_to_queue
:
self
.
_finish_stream
(
request_
id
)
self
.
rid_to_queue
[
request_id
].
put_nowait
(
request_
output
)
async
def
_run_output_handler
(
self
):
async
def
_run_output_handler
(
self
):
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
"""Background loop: pulls from EngineCore and pushes to AsyncStreams."""
...
@@ -302,24 +268,27 @@ class AsyncLLM(EngineClient):
...
@@ -302,24 +268,27 @@ class AsyncLLM(EngineClient):
# 2) Detokenize based on the output.
# 2) Detokenize based on the output.
request_outputs
,
reqs_to_abort
=
self
.
detokenizer
.
step
(
outputs
)
request_outputs
,
reqs_to_abort
=
self
.
detokenizer
.
step
(
outputs
)
# 3) Put the RequestOutputs into the per-request
AsyncStream
s.
# 3) Put the RequestOutputs into the per-request
queue
s.
self
.
_process_request_outputs
(
request_outputs
)
self
.
_process_request_outputs
(
request_outputs
)
# 4) Abort any requests that finished due to stop strings.
# 4) Abort any requests that finished due to stop strings.
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
await
self
.
engine_core
.
abort_requests_async
(
reqs_to_abort
)
# 5) Abort any requests due to client cancellations.
await
self
.
_process_cancellations
()
except
BaseException
as
e
:
except
BaseException
as
e
:
logger
.
error
(
e
)
logger
.
error
(
e
)
raise
e
raise
e
# TODO: can we eliminate these?
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
# Note: Who Calls this? I dont think this is actually used.
"""Abort RequestId in self, detokenizer, and engine core."""
raise
ValueError
(
"Not Supported on V1 yet."
)
request_ids
=
[
request_id
]
await
self
.
engine_core
.
abort_requests_async
(
request_ids
)
self
.
detokenizer
.
abort_requests
(
request_ids
)
# If a request finishes while we await then the request_id
# will be removed from the tracked queues before we get here.
if
request_id
in
self
.
rid_to_queue
:
del
self
.
rid_to_queue
[
request_id
]
def
encode
(
def
encode
(
self
,
self
,
...
@@ -382,7 +351,3 @@ class AsyncLLM(EngineClient):
...
@@ -382,7 +351,3 @@ class AsyncLLM(EngineClient):
@
property
@
property
def
dead_error
(
self
)
->
BaseException
:
def
dead_error
(
self
)
->
BaseException
:
return
Exception
()
# TODO: implement
return
Exception
()
# TODO: implement
# Retain V0 name for backwards compatibility.
AsyncLLMEngine
=
AsyncLLM
vllm/v1/engine/async_stream.py
deleted
100644 → 0
View file @
f9f4a735
import
asyncio
from
typing
import
Any
,
AsyncGenerator
,
Callable
,
Optional
,
Type
,
Union
from
vllm.outputs
import
PoolingRequestOutput
,
RequestOutput
class
AsyncStream
:
"""A stream of RequestOutputs or PoolingRequestOutputs for a request
that can be iterated over asynchronously via an async generator."""
STOP_ITERATION
=
Exception
()
# Sentinel
def
__init__
(
self
,
request_id
:
str
,
cancel
:
Callable
[[
str
],
None
])
->
None
:
self
.
request_id
=
request_id
self
.
_cancel
=
cancel
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
PoolingRequestOutput
,
Exception
])
->
None
:
if
not
self
.
_finished
:
self
.
_queue
.
put_nowait
(
item
)
def
finish
(
self
,
exception
:
Optional
[
Union
[
BaseException
,
Type
[
BaseException
]]]
=
None
,
)
->
None
:
if
not
self
.
_finished
:
self
.
_finished
=
True
self
.
_queue
.
put_nowait
(
exception
if
self
.
_is_raisable
(
exception
)
else
AsyncStream
.
STOP_ITERATION
)
async
def
generator
(
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
PoolingRequestOutput
],
None
]:
finished
=
False
try
:
while
True
:
result
=
await
self
.
_queue
.
get
()
if
self
.
_is_raisable
(
result
):
finished
=
True
if
result
==
AsyncStream
.
STOP_ITERATION
:
return
raise
result
yield
result
finally
:
self
.
_finished
=
True
if
not
finished
:
self
.
_cancel
(
self
.
request_id
)
@
staticmethod
def
_is_raisable
(
value
:
Any
):
return
isinstance
(
value
,
BaseException
)
or
\
(
isinstance
(
value
,
type
)
and
\
issubclass
(
value
,
BaseException
))
vllm/v1/engine/core.py
View file @
96ae75ad
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
...
@@ -32,7 +32,7 @@ logger = init_logger(__name__)
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_MS
=
5000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
POLLING_TIMEOUT_S
=
POLLING_TIMEOUT_MS
//
1000
LOGGING_TIME_S
=
5000
LOGGING_TIME_S
=
POLLING_TIMEOUT_S
class
EngineCore
:
class
EngineCore
:
...
@@ -65,7 +65,8 @@ class EngineCore:
...
@@ -65,7 +65,8 @@ class EngineCore:
self
.
_last_logging_time
=
time
.
time
()
self
.
_last_logging_time
=
time
.
time
()
self
.
mm_input_mapper_server
=
MMInputMapperServer
()
self
.
mm_input_mapper_server
=
MMInputMapperServer
(
vllm_config
.
model_config
)
def
_initialize_kv_caches
(
self
,
def
_initialize_kv_caches
(
self
,
cache_config
:
CacheConfig
)
->
Tuple
[
int
,
int
]:
cache_config
:
CacheConfig
)
->
Tuple
[
int
,
int
]:
...
@@ -98,9 +99,8 @@ class EngineCore:
...
@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
# entry here as well.
assert
request
.
mm_inputs
is
not
None
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
,
request
.
mm_hashes
=
(
request
.
mm_inputs
=
self
.
mm_input_mapper_server
.
process_inputs
(
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
)
request
.
mm_inputs
,
request
.
mm_hashes
))
req
=
Request
.
from_engine_core_request
(
request
)
req
=
Request
.
from_engine_core_request
(
request
)
...
...
vllm/v1/engine/llm_engine.py
View file @
96ae75ad
...
@@ -55,9 +55,12 @@ class LLMEngine:
...
@@ -55,9 +55,12 @@ class LLMEngine:
self
.
tokenizer
.
ping
()
self
.
tokenizer
.
ping
()
# Processor (convert Inputs --> EngineCoreRequests)
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
.
model_config
,
self
.
processor
=
Processor
(
model_config
=
vllm_config
.
model_config
,
vllm_config
.
lora_config
,
self
.
tokenizer
,
cache_config
=
vllm_config
.
cache_config
,
input_registry
,
mm_registry
)
lora_config
=
vllm_config
.
lora_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
mm_registry
=
mm_registry
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self
.
detokenizer
=
Detokenizer
(
self
.
detokenizer
=
Detokenizer
(
...
@@ -107,7 +110,10 @@ class LLMEngine:
...
@@ -107,7 +110,10 @@ class LLMEngine:
executor_class
:
Type
[
Executor
]
executor_class
:
Type
[
Executor
]
distributed_executor_backend
=
(
distributed_executor_backend
=
(
vllm_config
.
parallel_config
.
distributed_executor_backend
)
vllm_config
.
parallel_config
.
distributed_executor_backend
)
if
distributed_executor_backend
==
"mp"
:
if
distributed_executor_backend
==
"ray"
:
from
vllm.v1.executor.ray_executor
import
RayExecutor
executor_class
=
RayExecutor
elif
distributed_executor_backend
==
"mp"
:
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
from
vllm.v1.executor.multiproc_executor
import
MultiprocExecutor
executor_class
=
MultiprocExecutor
executor_class
=
MultiprocExecutor
else
:
else
:
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
96ae75ad
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
import
PIL
import
PIL
from
blake3
import
blake3
from
blake3
import
blake3
...
@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
...
@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalKwargs
,
MultiModalRegistry
)
MultiModalKwargs
,
MultiModalRegistry
)
from
vllm.
v1.
utils
import
LRU
Dict
Cache
from
vllm.utils
import
LRUCache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -42,7 +42,9 @@ class MMInputMapperClient:
...
@@ -42,7 +42,9 @@ class MMInputMapperClient:
model_config
)
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
# Init cache
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
LRUCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
# DEBUG: Set to None to disable
self
.
mm_debug_cache_hit_ratio_steps
=
None
self
.
mm_debug_cache_hit_ratio_steps
=
None
...
@@ -61,7 +63,7 @@ class MMInputMapperClient:
...
@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes
:
Optional
[
List
[
str
]],
mm_hashes
:
Optional
[
List
[
str
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
)
->
Tuple
[
List
[
MultiModalKwargs
]
,
Optional
[
List
[
str
]]]
:
)
->
List
[
MultiModalKwargs
]:
if
precomputed_mm_inputs
is
None
:
if
precomputed_mm_inputs
is
None
:
image_inputs
=
mm_data
[
"image"
]
image_inputs
=
mm_data
[
"image"
]
if
not
isinstance
(
image_inputs
,
list
):
if
not
isinstance
(
image_inputs
,
list
):
...
@@ -70,26 +72,21 @@ class MMInputMapperClient:
...
@@ -70,26 +72,21 @@ class MMInputMapperClient:
else
:
else
:
num_inputs
=
len
(
precomputed_mm_inputs
)
num_inputs
=
len
(
precomputed_mm_inputs
)
# Check if hash is enabled
# Sanity
use_hash
=
mm_hashes
is
not
None
if
self
.
use_cache
:
if
use_hash
:
assert
mm_hashes
is
not
None
assert
mm_hashes
is
not
None
assert
num_inputs
==
len
(
assert
num_inputs
==
len
(
mm_hashes
)
mm_hashes
),
"num_inputs = {} len(mm_hashes) = {}"
.
format
(
num_inputs
,
len
(
mm_hashes
))
# Process each image input separately, so that later we can schedule
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes
:
Optional
[
List
[
str
]]
=
[]
if
use_hash
else
None
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
for
input_id
in
range
(
num_inputs
):
for
input_id
in
range
(
num_inputs
):
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
self
.
cache_hit_ratio
(
self
.
mm_debug_cache_hit_ratio_steps
)
self
.
cache_hit_ratio
(
self
.
mm_debug_cache_hit_ratio_steps
)
mm_hash
=
None
mm_input
=
None
mm_input
=
None
if
use_hash
:
if
self
.
use_cache
:
assert
mm_hashes
is
not
None
assert
mm_hashes
is
not
None
mm_hash
=
mm_hashes
[
input_id
]
mm_hash
=
mm_hashes
[
input_id
]
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
...
@@ -106,7 +103,7 @@ class MMInputMapperClient:
...
@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
)
if
use_hash
:
if
self
.
use_cache
:
# Add to cache
# Add to cache
assert
mm_hash
is
not
None
assert
mm_hash
is
not
None
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
...
@@ -114,19 +111,16 @@ class MMInputMapperClient:
...
@@ -114,19 +111,16 @@ class MMInputMapperClient:
self
.
mm_cache_hits
+=
1
self
.
mm_cache_hits
+=
1
mm_input
=
None
# Avoids sending mm_input to Server
mm_input
=
None
# Avoids sending mm_input to Server
if
use_hash
:
assert
mm_hash
is
not
None
assert
ret_hashes
is
not
None
ret_hashes
.
append
(
mm_hash
)
ret_inputs
.
append
(
mm_input
)
ret_inputs
.
append
(
mm_input
)
return
ret_inputs
,
ret_hashes
return
ret_inputs
class
MMInputMapperServer
:
class
MMInputMapperServer
:
def
__init__
(
self
,
):
def
__init__
(
self
,
model_config
):
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
LRUCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
def
process_inputs
(
def
process_inputs
(
self
,
self
,
...
@@ -135,6 +129,9 @@ class MMInputMapperServer:
...
@@ -135,6 +129,9 @@ class MMInputMapperServer:
)
->
List
[
MultiModalKwargs
]:
)
->
List
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
return
mm_inputs
full_mm_inputs
=
[]
full_mm_inputs
=
[]
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
assert
mm_hash
is
not
None
...
@@ -154,12 +151,45 @@ class MMHasher:
...
@@ -154,12 +151,45 @@ class MMHasher:
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
hash
(
self
,
prompt
:
PromptType
)
->
Optional
[
List
[
str
]]:
def
hash_dummy_mm_data
(
self
,
mm_data
:
Optional
[
MultiModalDataDict
])
->
Optional
[
List
[
str
]]:
"""Hash user-defined dummy multimodal data used for profiling."""
if
mm_data
is
None
:
return
None
image_inputs
=
mm_data
[
'image'
]
# This is a temporary workaround for models (e.g, Molmo) that
# process multimodal data in the input processor (therefore
# image_inputs is MultiModalKwargs instead of raw input format).
# `raw_mm_data` with the original input format is expected
# in this case.
if
isinstance
(
image_inputs
,
dict
):
assert
"raw_mm_data"
in
image_inputs
and
isinstance
(
image_inputs
[
"raw_mm_data"
],
PIL
.
Image
.
Image
)
image_inputs
=
image_inputs
.
pop
(
"raw_mm_data"
)
return
self
.
hash_images
(
image_inputs
)
def
hash_prompt_mm_data
(
self
,
prompt
:
PromptType
)
->
Optional
[
List
[
str
]]:
"""Hash multimodal data in the user input prompt if they exist."""
if
"multi_modal_data"
not
in
prompt
:
if
"multi_modal_data"
not
in
prompt
:
return
None
return
None
mm_data
=
prompt
[
"multi_modal_data"
]
mm_data
=
prompt
[
"multi_modal_data"
]
if
not
mm_data
:
# mm_data can be None or an empty dict.
return
None
image_inputs
=
mm_data
[
"image"
]
image_inputs
=
mm_data
[
"image"
]
return
self
.
hash_images
(
image_inputs
)
def
hash_images
(
self
,
image_inputs
)
->
Optional
[
List
[
str
]]:
"""Hash PIL image objects to strings."""
if
not
isinstance
(
image_inputs
,
list
):
if
not
isinstance
(
image_inputs
,
list
):
image_inputs
=
[
image_inputs
]
image_inputs
=
[
image_inputs
]
assert
len
(
image_inputs
)
>
0
assert
len
(
image_inputs
)
>
0
...
...
vllm/v1/engine/processor.py
View file @
96ae75ad
import
time
import
time
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Tuple
,
Union
from
typing
import
Mapping
,
Optional
,
Tuple
,
Union
from
vllm.config
import
LoRAConfig
,
ModelConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
PromptType
,
SingletonInputsAdapter
)
PromptType
,
SingletonInputsAdapter
)
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
...
@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
...
@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
DetokenizerRequest
,
EngineCoreRequest
from
vllm.v1.engine
import
DetokenizerRequest
,
EngineCoreRequest
from
vllm.v1.engine.mm_input_mapper
import
MMHasher
,
MMInputMapperClient
from
vllm.v1.engine.mm_input_mapper
import
MMHasher
,
MMInputMapperClient
...
@@ -23,6 +22,7 @@ class Processor:
...
@@ -23,6 +22,7 @@ class Processor:
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
tokenizer
:
BaseTokenizerGroup
,
tokenizer
:
BaseTokenizerGroup
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
...
@@ -33,8 +33,8 @@ class Processor:
...
@@ -33,8 +33,8 @@ class Processor:
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
self
.
tokenizer
=
tokenizer
self
.
tokenizer
=
tokenizer
self
.
generation_config_fields
=
_load
_generation_config
_dict
(
self
.
generation_config_fields
=
model_config
.
try_get
_generation_config
(
model_config
)
)
self
.
input_preprocessor
=
InputPreprocessor
(
model_config
,
self
.
input_preprocessor
=
InputPreprocessor
(
model_config
,
self
.
tokenizer
,
self
.
tokenizer
,
mm_registry
)
mm_registry
)
...
@@ -45,8 +45,9 @@ class Processor:
...
@@ -45,8 +45,9 @@ class Processor:
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
# Multi-modal hasher (for images)
# Multi-modal hasher (for images)
self
.
mm_hasher
=
MMHasher
(
self
.
use_hash
=
(
not
model_config
.
disable_mm_preprocessor_cache
)
or
\
)
if
model_config
.
mm_cache_preprocessor
else
None
cache_config
.
enable_prefix_caching
self
.
mm_hasher
=
MMHasher
()
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# This ideally should releases the GIL, so we should not block the
# This ideally should releases the GIL, so we should not block the
...
@@ -77,8 +78,8 @@ class Processor:
...
@@ -77,8 +78,8 @@ class Processor:
# Compute MM hashes (if enabled)
# Compute MM hashes (if enabled)
mm_hashes
=
None
mm_hashes
=
None
if
self
.
mm
_hash
er
is
not
None
:
if
self
.
use
_hash
:
mm_hashes
=
self
.
mm_hasher
.
hash
(
prompt
)
mm_hashes
=
self
.
mm_hasher
.
hash
_prompt_mm_data
(
prompt
)
# Process inputs.
# Process inputs.
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
...
@@ -118,7 +119,7 @@ class Processor:
...
@@ -118,7 +119,7 @@ class Processor:
# Apply MM mapper
# Apply MM mapper
mm_inputs
=
None
mm_inputs
=
None
if
len
(
decoder_inputs
.
multi_modal_data
)
>
0
:
if
len
(
decoder_inputs
.
multi_modal_data
)
>
0
:
mm_inputs
,
mm_hashes
=
self
.
mm_input_mapper_client
.
process_inputs
(
mm_inputs
=
self
.
mm_input_mapper_client
.
process_inputs
(
decoder_inputs
.
multi_modal_data
,
mm_hashes
,
decoder_inputs
.
multi_modal_data
,
mm_hashes
,
decoder_inputs
.
mm_processor_kwargs
,
precomputed_mm_inputs
)
decoder_inputs
.
mm_processor_kwargs
,
precomputed_mm_inputs
)
...
@@ -179,16 +180,3 @@ class Processor:
...
@@ -179,16 +180,3 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def
_load_generation_config_dict
(
model_config
:
ModelConfig
)
->
Dict
[
str
,
Any
]:
config
=
try_get_generation_config
(
model_config
.
model
,
trust_remote_code
=
model_config
.
trust_remote_code
,
revision
=
model_config
.
revision
,
)
if
config
is
None
:
return
{}
return
config
.
to_diff_dict
()
vllm/v1/executor/ray_executor.py
0 → 100644
View file @
96ae75ad
import
os
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.ray_utils
import
(
RayWorkerWrapper
,
initialize_ray_cluster
,
ray
)
from
vllm.v1.outputs
import
ModelRunnerOutput
if
ray
is
not
None
:
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
logger
=
init_logger
(
__name__
)
class
RayExecutor
(
Executor
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
)
->
None
:
self
.
vllm_config
=
vllm_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
model_config
=
vllm_config
.
model_config
self
.
forward_dag
:
Optional
[
ray
.
dag
.
CompiledDAG
]
=
None
# Disable Ray usage stats collection.
ray_usage
=
os
.
environ
.
get
(
"RAY_USAGE_STATS_ENABLED"
,
"0"
)
if
ray_usage
!=
"1"
:
os
.
environ
[
"RAY_USAGE_STATS_ENABLED"
]
=
"0"
initialize_ray_cluster
(
self
.
parallel_config
)
placement_group
=
self
.
parallel_config
.
placement_group
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
# A list of workers to run a model.
self
.
workers
:
List
[
RayWorkerWrapper
]
=
[]
if
self
.
parallel_config
.
ray_workers_use_nsight
:
ray_remote_kwargs
=
self
.
_configure_ray_workers_use_nsight
(
ray_remote_kwargs
)
# Create the workers.
driver_ip
=
get_ip
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
# Skip bundles that don't have GPUs,
# as each worker needs one GPU.
continue
scheduling_strategy
=
PlacementGroupSchedulingStrategy
(
placement_group
=
placement_group
,
placement_group_capture_child_tasks
=
True
,
placement_group_bundle_index
=
bundle_id
,
)
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
1
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
vllm_config
=
self
.
vllm_config
)
self
.
workers
.
append
(
worker
)
logger
.
debug
(
"workers: %s"
,
self
.
workers
)
worker_ips
=
[
ray
.
get
(
worker
.
get_node_ip
.
remote
())
# type: ignore[attr-defined]
for
worker
in
self
.
workers
]
ip_counts
:
Dict
[
str
,
int
]
=
{}
for
ip
in
worker_ips
:
ip_counts
[
ip
]
=
ip_counts
.
get
(
ip
,
0
)
+
1
worker_to_ip
=
dict
(
zip
(
self
.
workers
,
worker_ips
))
def
sort_by_driver_then_worker_ip
(
worker
):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first. This is simply a tiebreaker to make
sure the workers are sorted in a deterministic way.
"""
ip
=
worker_to_ip
[
worker
]
return
(
ip
!=
driver_ip
,
ip_counts
[
ip
],
ip
)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self
.
workers
=
sorted
(
self
.
workers
,
key
=
sort_by_driver_then_worker_ip
)
# Get the set of GPU IDs used on each node.
worker_node_and_gpu_ids
=
self
.
_run_workers
(
"get_node_and_gpu_ids"
)
node_workers
=
defaultdict
(
list
)
# node id -> list of worker ranks
node_gpus
=
defaultdict
(
list
)
# node id -> list of gpu ids
for
i
,
(
node_id
,
gpu_ids
)
in
enumerate
(
worker_node_and_gpu_ids
):
node_workers
[
node_id
].
append
(
i
)
# `gpu_ids` can be a list of strings or integers.
# convert them to integers for consistency.
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
# string sorting is not sufficient.
# see https://github.com/vllm-project/vllm/issues/5590
gpu_ids
=
[
int
(
x
)
for
x
in
gpu_ids
]
node_gpus
[
node_id
].
extend
(
gpu_ids
)
for
node_id
,
gpu_ids
in
node_gpus
.
items
():
node_gpus
[
node_id
]
=
sorted
(
gpu_ids
)
all_ips
=
set
(
worker_ips
)
n_ips
=
len
(
all_ips
)
n_nodes
=
len
(
node_workers
)
if
n_nodes
!=
n_ips
:
raise
RuntimeError
(
f
"Every node should have a unique IP address. Got
{
n_nodes
}
"
f
" nodes with node ids
{
list
(
node_workers
.
keys
())
}
and "
f
"
{
n_ips
}
unique IP addresses
{
all_ips
}
. Please check your"
" network configuration. If you set `VLLM_HOST_IP` or "
"`HOST_IP` environment variable, make sure it is unique for"
" each node."
)
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables
=
[({
"CUDA_VISIBLE_DEVICES"
:
","
.
join
(
map
(
str
,
node_gpus
[
node_id
])),
"VLLM_TRACE_FUNCTION"
:
str
(
envs
.
VLLM_TRACE_FUNCTION
),
"VLLM_USE_V1"
:
str
(
int
(
envs
.
VLLM_USE_V1
)),
**
({
"VLLM_ATTENTION_BACKEND"
:
envs
.
VLLM_ATTENTION_BACKEND
}
if
envs
.
VLLM_ATTENTION_BACKEND
is
not
None
else
{})
},
)
for
(
node_id
,
_
)
in
worker_node_and_gpu_ids
]
self
.
_env_vars_for_all_workers
=
(
all_args_to_update_environment_variables
)
self
.
_run_workers
(
"update_environment_variables"
,
all_args
=
self
.
_get_env_vars_to_be_updated
())
if
len
(
node_gpus
)
==
1
:
# in single node case, we don't need to get the IP address.
# the loopback address is sufficient
# NOTE: a node may have several IP addresses, one for each
# network interface. `get_ip()` might return any of them,
# while they might not work for communication inside the node
# if the network setup is complicated. Using the loopback address
# solves this issue, as it always works for communication inside
# the node.
driver_ip
=
"127.0.0.1"
distributed_init_method
=
get_distributed_init_method
(
driver_ip
,
get_open_port
())
# Initialize the actual workers inside worker wrapper.
init_worker_all_kwargs
=
[
self
.
_get_worker_kwargs
(
local_rank
=
node_workers
[
node_id
].
index
(
rank
),
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)
for
rank
,
(
node_id
,
_
)
in
enumerate
(
worker_node_and_gpu_ids
)
]
self
.
_run_workers
(
"init_worker"
,
all_kwargs
=
init_worker_all_kwargs
)
self
.
_run_workers
(
"initialize"
)
self
.
_run_workers
(
"load_model"
)
def
_configure_ray_workers_use_nsight
(
self
,
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env
=
ray_remote_kwargs
.
setdefault
(
"runtime_env"
,
{})
runtime_env
.
update
({
"nsight"
:
{
"t"
:
"cuda,cudnn,cublas"
,
"o"
:
"'worker_process_%p'"
,
"cuda-graph-trace"
:
"node"
,
}
})
return
ray_remote_kwargs
def
_get_env_vars_to_be_updated
(
self
):
return
self
.
_env_vars_for_all_workers
def
_get_worker_kwargs
(
self
,
local_rank
:
int
=
0
,
rank
:
int
=
0
,
distributed_init_method
:
Optional
[
str
]
=
None
)
->
Dict
[
str
,
Any
]:
"""
Return worker init args for a given rank.
"""
if
distributed_init_method
is
None
:
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
return
dict
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
)
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""
Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes
the min of the results, guaranteeing that the selected cache sizes are
compatible with all workers.
Returns:
- tuple[num_gpu_blocks, num_cpu_blocks]
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks
=
self
.
_run_workers
(
"determine_num_available_blocks"
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks
=
min
(
b
[
0
]
for
b
in
num_blocks
)
num_cpu_blocks
=
min
(
b
[
1
]
for
b
in
num_blocks
)
return
num_gpu_blocks
,
num_cpu_blocks
def
initialize
(
self
,
num_gpu_blocks
:
int
)
->
None
:
"""
Initialize the KV cache in all workers.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger
.
info
(
"# GPU blocks: %d"
,
num_gpu_blocks
)
self
.
_run_workers
(
"initialize_cache"
,
num_gpu_blocks
)
self
.
_run_workers
(
"compile_or_warm_up_model"
)
def
_run_workers
(
self
,
method
:
str
,
*
args
,
all_args
:
Optional
[
List
[
Tuple
[
Any
,
...]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
**
kwargs
,
)
->
Any
:
"""
Runs the given method on all workers. Can be used in the following
ways:
Args:
- args/kwargs: All workers share the same args/kwargs
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
count
=
len
(
self
.
workers
)
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
0
,
None
)
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
0
,
None
)
ray_worker_refs
=
[
worker
.
execute_method
.
remote
(
# type: ignore[attr-defined]
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
return
ray
.
get
(
ray_worker_refs
)
def
execute_model
(
self
,
scheduler_output
,
)
->
ModelRunnerOutput
:
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
# Only the first worker (with rank 0) returns the execution result.
# Others return None.
output
=
ray
.
get
(
self
.
forward_dag
.
execute
(
scheduler_output
))[
0
]
return
output
def
profile
(
self
,
is_start
=
True
):
raise
NotImplementedError
def
shutdown
(
self
):
if
hasattr
(
self
,
"forward_dag"
)
and
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
for
worker
in
self
.
workers
:
ray
.
kill
(
worker
)
self
.
forward_dag
=
None
def
check_health
(
self
)
->
None
:
logger
.
debug
(
"Called check_health."
)
def
_check_ray_compiled_graph_installation
(
self
):
import
pkg_resources
from
packaging
import
version
required_version
=
version
.
parse
(
"2.39"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
f
"required, but found
{
current_version
}
"
)
import
importlib.util
raycg
=
importlib
.
util
.
find_spec
(
"ray.experimental.compiled_dag_ref"
)
if
raycg
is
None
:
raise
ValueError
(
"Ray Compiled Graph is not installed. "
"Run `pip install ray[adag]` to install it."
)
cupy_spec
=
importlib
.
util
.
find_spec
(
"cupy"
)
if
cupy_spec
is
None
and
envs
.
VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL
:
raise
ValueError
(
"cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set."
"Run `pip install ray[adag]` and check cupy installation."
)
def
_compiled_ray_dag
(
self
):
assert
self
.
parallel_config
.
use_ray
self
.
_check_ray_compiled_graph_installation
()
from
ray.dag
import
InputNode
,
MultiOutputNode
with
InputNode
()
as
input_batches
:
outputs
=
[
worker
.
execute_model
.
bind
(
# type: ignore[attr-defined]
input_batches
)
for
worker
in
self
.
workers
]
forward_dag
=
MultiOutputNode
(
outputs
)
return
forward_dag
.
experimental_compile
()
def
__del__
(
self
):
self
.
shutdown
()
vllm/v1/executor/ray_utils.py
0 → 100644
View file @
96ae75ad
import
time
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
get_ip
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.worker.worker_base
import
WorkerWrapperBase
if
TYPE_CHECKING
:
from
vllm.v1.core.scheduler
import
SchedulerOutput
logger
=
init_logger
(
__name__
)
PG_WAIT_TIMEOUT
=
60
try
:
import
ray
from
ray.util
import
placement_group_table
from
ray.util.placement_group
import
PlacementGroup
try
:
from
ray._private.state
import
available_resources_per_node
except
ImportError
:
# Ray 2.9.x doesn't expose `available_resources_per_node`
from
ray._private.state
import
state
as
_state
available_resources_per_node
=
_state
.
_available_resources_per_node
class
RayWorkerWrapper
(
WorkerWrapperBase
):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
# Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on
# that thread. It will be removed soon.
self
.
compiled_dag_cuda_device_set
=
False
def
get_node_ip
(
self
)
->
str
:
return
get_ip
()
def
get_node_and_gpu_ids
(
self
)
->
Tuple
[
str
,
List
[
int
]]:
node_id
=
ray
.
get_runtime_context
().
get_node_id
()
gpu_ids
=
ray
.
get_gpu_ids
()
return
node_id
,
gpu_ids
def
setup_device_if_necessary
(
self
):
# TODO(swang): This is needed right now because Ray CG executes
# on a background thread, so we need to reset torch's current
# device.
# We can remove this API after it is fixed in compiled graph.
import
torch
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
if
not
self
.
compiled_dag_cuda_device_set
:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
def
execute_model
(
self
,
scheduler_output
:
"SchedulerOutput"
,
)
->
ModelRunnerOutput
:
self
.
setup_device_if_necessary
()
assert
self
.
worker
is
not
None
,
"Worker is not initialized"
output
=
self
.
worker
.
model_runner
.
execute_model
(
scheduler_output
)
return
output
ray_import_err
=
None
except
ImportError
as
e
:
ray
=
None
# type: ignore
ray_import_err
=
e
RayWorkerWrapper
=
None
# type: ignore
def
ray_is_available
()
->
bool
:
"""Returns True if Ray is available."""
return
ray
is
not
None
def
assert_ray_available
():
"""
Raise an exception if Ray is not available.
"""
if
ray
is
None
:
raise
ValueError
(
"Failed to import Ray, please install Ray with "
"`pip install ray`."
)
from
ray_import_err
def
_verify_bundles
(
placement_group
:
"PlacementGroup"
,
parallel_config
:
ParallelConfig
,
device_str
:
str
):
"""
Verify a given placement group has bundles located in the right place.
There are 2 rules.
- Warn if all tensor parallel workers cannot fit in a single node.
- Fail if driver node is not included in a placement group.
Args:
placement_group: The placement group to verify.
parallel_config: The parallel configuration.
device_str: The required device.
"""
assert
ray
.
is_initialized
(),
(
"Ray is not initialized although distributed-executor-backend is ray."
)
pg_data
=
placement_group_table
(
placement_group
)
# bundle_idx -> node_id
bundle_to_node_ids
=
pg_data
[
"bundles_to_node_id"
]
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles
=
pg_data
[
"bundles"
]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle
:
Dict
[
str
,
List
[
Dict
[
str
,
float
]]]
=
defaultdict
(
list
)
for
bundle_idx
,
node_id
in
bundle_to_node_ids
.
items
():
node_id_to_bundle
[
node_id
].
append
(
bundles
[
bundle_idx
])
driver_node_id
=
ray
.
get_runtime_context
().
get_node_id
()
if
driver_node_id
not
in
node_id_to_bundle
:
raise
RuntimeError
(
f
"driver node id
{
driver_node_id
}
is not included in a placement "
f
"group
{
placement_group
.
id
}
. Node id -> bundles "
f
"
{
node_id_to_bundle
}
. "
"You don't have enough GPUs available in a current node. Check "
"`ray status` to see if you have available GPUs in a node "
f
"
{
driver_node_id
}
before starting an vLLM engine."
)
for
node_id
,
bundles
in
node_id_to_bundle
.
items
():
if
len
(
bundles
)
<
parallel_config
.
tensor_parallel_size
:
logger
.
warning
(
"tensor_parallel_size=%d "
"is bigger than a reserved number of %ss (%d "
"%ss) in a node %s. Tensor parallel workers can be "
"spread out to 2+ nodes which can degrade the performance "
"unless you have fast interconnect across nodes, like "
"Infiniband. To resolve this issue, make sure you have more "
"than %d GPUs available at each node."
,
parallel_config
.
tensor_parallel_size
,
device_str
,
len
(
bundles
),
device_str
,
node_id
,
parallel_config
.
tensor_parallel_size
)
def
_wait_until_pg_ready
(
current_placement_group
:
"PlacementGroup"
):
"""Wait until a placement group is ready.
It prints the informative log messages if the placement group is
not created within time.
"""
# Wait until PG is ready - this will block until all
# requested resources are available, and will timeout
# if they cannot be provisioned.
placement_group_specs
=
current_placement_group
.
bundle_specs
s
=
time
.
time
()
pg_ready_ref
=
current_placement_group
.
ready
()
wait_interval
=
10
while
time
.
time
()
-
s
<
PG_WAIT_TIMEOUT
:
ready
,
_
=
ray
.
wait
([
pg_ready_ref
],
timeout
=
wait_interval
)
if
len
(
ready
)
>
0
:
break
# Exponential backoff for warning print.
wait_interval
*=
2
logger
.
info
(
"Waiting for creating a placement group of specs for "
"%d seconds. specs=%s. Check "
"`ray status` to see if you have enough resources."
,
int
(
time
.
time
()
-
s
),
placement_group_specs
)
try
:
ray
.
get
(
pg_ready_ref
,
timeout
=
0
)
except
ray
.
exceptions
.
GetTimeoutError
:
raise
ValueError
(
"Cannot provide a placement group of "
f
"
{
placement_group_specs
=
}
within
{
PG_WAIT_TIMEOUT
}
seconds. See "
"`ray status` to make sure the cluster has enough resources."
)
from
None
def
initialize_ray_cluster
(
parallel_config
:
ParallelConfig
,
ray_address
:
Optional
[
str
]
=
None
,
):
"""Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
"""
assert_ray_available
()
# Connect to a ray cluster.
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
# Try to connect existing ray instance and create a new one if not found
try
:
ray
.
init
(
"auto"
)
except
ConnectionError
:
logger
.
warning
(
"No existing RAY instance detected. "
"A new instance will be launched with current node resources."
)
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
,
num_gpus
=
parallel_config
.
world_size
)
else
:
ray
.
init
(
address
=
ray_address
,
ignore_reinit_error
=
True
)
if
parallel_config
.
placement_group
:
# Placement group is already set.
return
device_str
=
"GPU"
if
not
current_platform
.
is_tpu
()
else
"TPU"
# Create placement group for worker processes
current_placement_group
=
ray
.
util
.
get_current_placement_group
()
if
current_placement_group
:
# We are in a placement group
bundles
=
current_placement_group
.
bundle_specs
# Verify that we can use the placement group.
device_bundles
=
0
for
bundle
in
bundles
:
bundle_devices
=
bundle
.
get
(
device_str
,
0
)
if
bundle_devices
>
1
:
raise
ValueError
(
"Placement group bundle cannot have more than 1 "
f
"
{
device_str
}
."
)
if
bundle_devices
:
device_bundles
+=
1
if
parallel_config
.
world_size
>
device_bundles
:
raise
ValueError
(
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group."
f
"Required number of devices:
{
parallel_config
.
world_size
}
. "
f
"Total number of devices:
{
device_bundles
}
."
)
else
:
num_devices_in_cluster
=
ray
.
cluster_resources
().
get
(
device_str
,
0
)
if
parallel_config
.
world_size
>
num_devices_in_cluster
:
raise
ValueError
(
f
"The number of required
{
device_str
}
s exceeds the total "
f
"number of available
{
device_str
}
s in the placement group."
)
# Create a new placement group
placement_group_specs
:
List
[
Dict
[
str
,
float
]]
=
([{
device_str
:
1.0
}
for
_
in
range
(
parallel_config
.
world_size
)])
# vLLM engine is also a worker to execute model with an accelerator,
# so it requires to have the device in a current node. Check if
# the current node has at least one device.
current_ip
=
get_ip
()
current_node_id
=
ray
.
get_runtime_context
().
get_node_id
()
current_node_resource
=
available_resources_per_node
()[
current_node_id
]
if
current_node_resource
.
get
(
device_str
,
0
)
<
1
:
raise
ValueError
(
f
"Current node has no
{
device_str
}
available. "
f
"
{
current_node_resource
=
}
. vLLM engine cannot start without "
f
"
{
device_str
}
. Make sure you have at least 1
{
device_str
}
"
f
"available in a node
{
current_node_id
=
}
{
current_ip
=
}
."
)
# This way, at least bundle is required to be created in a current
# node.
placement_group_specs
[
0
][
f
"node:
{
current_ip
}
"
]
=
0.001
# By default, Ray packs resources as much as possible.
current_placement_group
=
ray
.
util
.
placement_group
(
placement_group_specs
,
strategy
=
"PACK"
)
_wait_until_pg_ready
(
current_placement_group
)
assert
current_placement_group
is
not
None
_verify_bundles
(
current_placement_group
,
parallel_config
,
device_str
)
# Set the placement group in the parallel config
parallel_config
.
placement_group
=
current_placement_group
vllm/v1/request.py
View file @
96ae75ad
import
enum
import
enum
from
typing
import
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
from
vllm.inputs
import
DecoderOnlyInputs
,
SingletonInputsAdapter
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
SingletonInputsAdapter
,
token_inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
...
@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.utils
import
ConstantList
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
class
Request
:
class
Request
:
...
@@ -45,6 +48,7 @@ class Request:
...
@@ -45,6 +48,7 @@ class Request:
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
num_computed_tokens
=
0
self
.
num_computed_tokens
=
0
# Multi-modal input metadata.
mm_positions
=
self
.
inputs
.
multi_modal_placeholders
mm_positions
=
self
.
inputs
.
multi_modal_placeholders
if
mm_positions
:
if
mm_positions
:
# FIXME(woosuk): Support other modalities.
# FIXME(woosuk): Support other modalities.
...
@@ -56,6 +60,12 @@ class Request:
...
@@ -56,6 +60,12 @@ class Request:
if
self
.
inputs
.
multi_modal_inputs
:
if
self
.
inputs
.
multi_modal_inputs
:
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
self
.
mm_hashes
:
List
[
str
]
=
self
.
inputs
.
multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self
.
_kv_block_hashes
:
List
[
BlockHashType
]
=
[]
@
classmethod
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
return
cls
(
return
cls
(
...
@@ -65,6 +75,7 @@ class Request:
...
@@ -65,6 +75,7 @@ class Request:
prompt
=
request
.
prompt
,
prompt
=
request
.
prompt
,
multi_modal_data
=
None
,
multi_modal_data
=
None
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
mm_processor_kwargs
=
None
,
mm_processor_kwargs
=
None
,
),
),
...
@@ -121,6 +132,17 @@ class Request:
...
@@ -121,6 +132,17 @@ class Request:
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
return
num_tokens
return
num_tokens
@
property
def
kv_block_hashes
(
self
)
->
ConstantList
[
"BlockHashType"
]:
# Prevent directly appending to the kv_block_hashes.
return
ConstantList
(
self
.
_kv_block_hashes
)
def
set_kv_block_hashes
(
self
,
value
:
List
[
"BlockHashType"
])
->
None
:
self
.
_kv_block_hashes
=
value
def
append_kv_block_hashes
(
self
,
block_hash
:
"BlockHashType"
)
->
None
:
self
.
_kv_block_hashes
.
append
(
block_hash
)
class
RequestStatus
(
enum
.
IntEnum
):
class
RequestStatus
(
enum
.
IntEnum
):
"""Status of a request."""
"""Status of a request."""
...
...
vllm/v1/sample/metadata.py
View file @
96ae75ad
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
from
typing
import
Dict
,
List
,
Optional
,
Set
import
torch
import
torch
...
@@ -19,3 +19,13 @@ class SamplingMetadata:
...
@@ -19,3 +19,13 @@ class SamplingMetadata:
generators
:
Dict
[
int
,
torch
.
Generator
]
generators
:
Dict
[
int
,
torch
.
Generator
]
max_num_logprobs
:
int
max_num_logprobs
:
int
no_penalties
:
bool
prompt_token_ids
:
Optional
[
torch
.
Tensor
]
frequency_penalties
:
torch
.
Tensor
presence_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
output_token_ids
:
List
[
List
[
int
]]
min_tokens
:
List
[
int
]
stop_token_ids
:
List
[
Set
[
int
]]
vllm/v1/sample/ops/__init__.py
0 → 100644
View file @
96ae75ad
Prev
1
…
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment