Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
897 additions
and
338 deletions
+897
-338
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+23
-7
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+27
-17
vllm/core/evictor_v1.py
vllm/core/evictor_v1.py
+0
-0
vllm/core/evictor_v2.py
vllm/core/evictor_v2.py
+127
-0
vllm/core/interfaces.py
vllm/core/interfaces.py
+1
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+75
-25
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+48
-22
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+8
-8
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+19
-21
vllm/distributed/device_communicators/pynccl_utils.py
vllm/distributed/device_communicators/pynccl_utils.py
+2
-2
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+31
-15
vllm/distributed/utils.py
vllm/distributed/utils.py
+7
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+76
-16
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+40
-23
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+198
-100
vllm/engine/metrics.py
vllm/engine/metrics.py
+171
-61
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+6
-0
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+20
-5
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+14
-8
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+4
-3
No files found.
vllm/core/block_manager_v1.py
View file @
1591c68f
"""A block manager that manages token blocks."""
"""A block manager that manages token blocks."""
import
math
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
itertools
import
count
,
takewhile
from
itertools
import
count
,
takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
...
@@ -7,7 +8,7 @@ from typing import Sequence as GenericSequence
...
@@ -7,7 +8,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Set
from
typing
import
Set
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.evictor
_v1
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
...
@@ -46,6 +47,10 @@ class BlockAllocatorBase(ABC):
...
@@ -46,6 +47,10 @@ class BlockAllocatorBase(ABC):
def
get_num_free_blocks
(
self
)
->
int
:
def
get_num_free_blocks
(
self
)
->
int
:
pass
pass
@
abstractmethod
def
get_num_total_blocks
(
self
)
->
int
:
pass
@
abstractmethod
@
abstractmethod
def
contains_block
(
self
,
block_hash
:
int
)
->
bool
:
def
contains_block
(
self
,
block_hash
:
int
)
->
bool
:
pass
pass
...
@@ -130,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
...
@@ -130,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
return
(
self
.
num_blocks
-
self
.
current_num_blocks
+
return
(
self
.
num_blocks
-
self
.
current_num_blocks
+
self
.
evictor
.
num_blocks
)
self
.
evictor
.
num_blocks
)
def
get_num_total_blocks
(
self
)
->
int
:
return
self
.
num_blocks
def
contains_block
(
self
,
block_hash
:
int
)
->
bool
:
def
contains_block
(
self
,
block_hash
:
int
)
->
bool
:
return
block_hash
in
self
.
cached_blocks
or
block_hash
in
self
.
evictor
return
block_hash
in
self
.
cached_blocks
or
block_hash
in
self
.
evictor
...
@@ -189,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
...
@@ -189,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
def
get_num_free_blocks
(
self
)
->
int
:
def
get_num_free_blocks
(
self
)
->
int
:
return
len
(
self
.
free_blocks
)
return
len
(
self
.
free_blocks
)
def
get_num_total_blocks
(
self
)
->
int
:
return
self
.
num_blocks
def
contains_block
(
self
,
block_hash
:
int
)
->
bool
:
def
contains_block
(
self
,
block_hash
:
int
)
->
bool
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Invalid codepath for uncached block allocator."
)
"Invalid codepath for uncached block allocator."
)
...
@@ -220,9 +231,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -220,9 +231,9 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self
.
block_sliding_window
=
None
self
.
block_sliding_window
=
None
if
sliding_window
is
not
None
:
if
sliding_window
is
not
None
:
assert
sliding_window
%
block
_
size
==
0
,
(
sliding
_
window
,
# Round up to nearest
block
size
to regularize
sliding
window
block_
size
)
# allocation
size
s.
self
.
block_sliding_window
=
sliding_window
/
/
block_size
self
.
block_sliding_window
=
math
.
ceil
(
sliding_window
/
block_size
)
self
.
watermark
=
watermark
self
.
watermark
=
watermark
assert
watermark
>=
0.0
assert
watermark
>=
0.0
...
@@ -390,7 +401,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -390,7 +401,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
block_table
.
append
(
block_table
[
len
(
block_table
)
%
block_table
.
append
(
block_table
[
len
(
block_table
)
%
self
.
block_sliding_window
])
self
.
block_sliding_window
])
else
:
else
:
# The sequence has a new logical block.
# The sequence has a new logical block.
# Allocate a new physical block.
# Allocate a new physical block.
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
block_table
.
append
(
new_block
)
block_table
.
append
(
new_block
)
...
@@ -443,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -443,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def
can_swap_in
(
self
,
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
bool
:
num_lookahead_slots
:
int
=
0
)
->
AllocStatus
:
assert
(
num_lookahead_slots
==
0
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
),
"BlockSpaceManagerV1 does not support lookahead allocation"
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
...
@@ -453,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -453,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# at least one free block right after the swap-in.
# at least one free block right after the swap-in.
# NOTE: This should match the logic in can_append_slot().
# NOTE: This should match the logic in can_append_slot().
num_required_blocks
=
len
(
blocks
)
+
num_swapped_seqs
num_required_blocks
=
len
(
blocks
)
+
num_swapped_seqs
return
num_free_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
if
self
.
gpu_allocator
.
get_num_total_blocks
()
<
num_required_blocks
:
return
AllocStatus
.
NEVER
elif
num_free_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
:
return
AllocStatus
.
OK
else
:
return
AllocStatus
.
LATER
def
swap_in
(
self
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
...
...
vllm/core/block_manager_v2.py
View file @
1591c68f
...
@@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
watermark
=
watermark
self
.
watermark
=
watermark
assert
watermark
>=
0.0
assert
watermark
>=
0.0
assert
not
enable_caching
,
"Prefix caching not yet supported"
self
.
enable_caching
=
enable_caching
self
.
enable_caching
=
enable_caching
self
.
watermark_blocks
=
int
(
watermark
*
num_gpu_blocks
)
self
.
watermark_blocks
=
int
(
watermark
*
num_gpu_blocks
)
self
.
block_allocator
=
CpuGpuBlockAllocator
.
create
(
self
.
block_allocator
=
CpuGpuBlockAllocator
.
create
(
# Currently, only naive blocks are supported (no prefix caching).
allocator_type
=
"prefix_caching"
if
enable_caching
else
"naive"
,
allocator_type
=
"naive"
,
num_gpu_blocks
=
num_gpu_blocks
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
,
num_cpu_blocks
=
num_cpu_blocks
,
block_size
=
block_size
,
block_size
=
block_size
,
...
@@ -192,19 +190,30 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -192,19 +190,30 @@ class BlockSpaceManagerV2(BlockSpaceManager):
assert
seq
.
seq_id
in
self
.
block_tables
assert
seq
.
seq_id
in
self
.
block_tables
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
assert
all
(
b
is
not
None
for
b
in
block_ids
)
assert
all
(
b
is
not
None
for
b
in
block_ids
)
return
block_ids
return
block_ids
# type: ignore
def
access_all_blocks_in_seq
(
self
,
seq
,
now
):
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
now
:
float
):
# TODO add prefix caching support.
# Update the last accessed time of all the blocks accessed
# Tracked here https://github.com/vllm-project/vllm/issues/3667
# in this step.
pass
# The access time is currently only useful for prefix caching,
# where the internal eviction policy uses it to decide which cached
# block to reclaim, so that cached content can be reused
# to the maximum extent.
if
self
.
enable_caching
:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_ids
=
[]
for
block_id
in
block_table
.
physical_block_ids
:
block_ids
.
append
(
block_id
)
self
.
block_allocator
.
mark_blocks_as_accessed
(
block_ids
,
# type: ignore
now
)
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
#
We ignore the sequence group as its not necessary. After the batch is
#
Marking blocks as computed is only needed for prefix caching,
#
formed by the scheduler, we do not need to mark blocks from individual
#
while currently we could determine whether one block is computed
#
sequence groups as computed -- all blocks in the batch can be marked
#
or not by checking whether it has a content hash.
#
as computed
.
#
So this function is useless for block_v2
.
self
.
block_allocator
.
mark_blocks_as_computed
()
pass
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seqs
:
List
[
Sequence
])
->
GenericSequence
[
int
]:
self
,
seqs
:
List
[
Sequence
])
->
GenericSequence
[
int
]:
...
@@ -220,16 +229,17 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -220,16 +229,17 @@ class BlockSpaceManagerV2(BlockSpaceManager):
seq_block_ids
=
[
seq_block_ids
=
[
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
for
seq
in
seqs
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
for
seq
in
seqs
]
]
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return
self
.
block_allocator
.
get_common_computed_block_ids
(
return
self
.
block_allocator
.
get_common_computed_block_ids
(
seq_block_ids
)
seq_block_ids
)
# type: ignore
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
num_lookahead_slots
:
int
)
->
AllocStatus
:
return
False
return
AllocStatus
.
LATER
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
...
...
vllm/core/evictor.py
→
vllm/core/evictor
_v1
.py
View file @
1591c68f
File moved
vllm/core/evictor_v2.py
0 → 100644
View file @
1591c68f
import
enum
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
typing
import
OrderedDict
,
Tuple
class
EvictionPolicy
(
enum
.
Enum
):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU
=
enum
.
auto
()
class
Evictor
(
ABC
):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@
abstractmethod
def
__init__
(
self
):
pass
@
abstractmethod
def
__contains__
(
self
,
block_id
:
int
)
->
bool
:
pass
@
abstractmethod
def
evict
(
self
)
->
Tuple
[
int
,
int
]:
"""Runs the eviction algorithm and returns the evicted block's
physical block id along with its content hash
"""
pass
@
abstractmethod
def
add
(
self
,
block_id
:
int
,
content_hash
:
int
,
num_hashed_tokens
:
int
,
last_accessed
:
float
):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@
abstractmethod
def
update
(
self
,
block_id
:
int
,
last_accessed
:
float
):
"""Update corresponding block's access time in metadata"""
pass
@
abstractmethod
def
remove
(
self
,
block_id
:
int
):
"""Remove a given block id from the cache."""
pass
@
abstractproperty
def
num_blocks
(
self
)
->
int
:
pass
class
BlockMetaData
():
"""Data structure for storing key data describe cached block, so that
the evictor can use it to decide which block to evict.
Here we use the physical block id as the dict key, as there may be several
blocks with the same content hash, but their physical id is unique.
"""
def
__init__
(
self
,
content_hash
:
int
,
num_hashed_tokens
:
int
,
last_accessed
:
float
):
self
.
content_hash
=
content_hash
self
.
num_hashed_tokens
=
num_hashed_tokens
self
.
last_accessed
=
last_accessed
class
LRUEvictor
(
Evictor
):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the BlockMetaData. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chosen arbitrarily
"""
def
__init__
(
self
):
self
.
free_table
:
OrderedDict
[
int
,
BlockMetaData
]
=
OrderedDict
()
def
__contains__
(
self
,
block_id
:
int
)
->
bool
:
return
block_id
in
self
.
free_table
def
evict
(
self
)
->
Tuple
[
int
,
int
]:
if
len
(
self
.
free_table
)
==
0
:
raise
ValueError
(
"No usable cache memory left"
)
evicted_block
=
next
(
iter
(
self
.
free_table
.
values
()))
evicted_block_id
=
next
(
iter
(
self
.
free_table
.
keys
()))
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for
_id
,
block
in
self
.
free_table
.
items
():
if
evicted_block
.
last_accessed
>
block
.
last_accessed
or
(
evicted_block
.
last_accessed
==
block
.
last_accessed
and
evicted_block
.
num_hashed_tokens
<
block
.
num_hashed_tokens
):
evicted_block
=
block
evicted_block_id
=
_id
self
.
free_table
.
pop
(
evicted_block_id
)
return
evicted_block_id
,
evicted_block
.
content_hash
def
add
(
self
,
block_id
:
int
,
content_hash
:
int
,
num_hashed_tokens
:
int
,
last_accessed
:
float
):
self
.
free_table
[
block_id
]
=
BlockMetaData
(
content_hash
,
num_hashed_tokens
,
last_accessed
)
def
update
(
self
,
block_id
:
int
,
last_accessed
:
float
):
self
.
free_table
[
block_id
].
last_accessed
=
last_accessed
def
remove
(
self
,
block_id
:
int
):
if
block_id
not
in
self
.
free_table
:
raise
ValueError
(
"Attempting to remove block that's not in the evictor"
)
self
.
free_table
.
pop
(
block_id
)
@
property
def
num_blocks
(
self
)
->
int
:
return
len
(
self
.
free_table
)
def
make_evictor
(
eviction_policy
:
EvictionPolicy
)
->
Evictor
:
if
eviction_policy
==
EvictionPolicy
.
LRU
:
return
LRUEvictor
()
else
:
raise
ValueError
(
f
"Unknown cache eviction policy:
{
eviction_policy
}
"
)
vllm/core/interfaces.py
View file @
1591c68f
...
@@ -63,7 +63,7 @@ class BlockSpaceManager(ABC):
...
@@ -63,7 +63,7 @@ class BlockSpaceManager(ABC):
@
abstractmethod
@
abstractmethod
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
num_lookahead_slots
:
int
)
->
AllocStatus
:
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/scheduler.py
View file @
1591c68f
import
enum
import
enum
import
os
import
random
import
time
import
time
from
collections
import
deque
from
collections
import
deque
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
...
@@ -15,6 +17,13 @@ from vllm.utils import merge_dicts
...
@@ -15,6 +17,13 @@ from vllm.utils import merge_dicts
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Test-only. If configured, decode is preempted with
# probability ARTIFICIAL_PREEMPTION_PROB (a fraction in [0, 1], not a percentage).
ENABLE_ARTIFICIAL_PREEMPT
=
bool
(
os
.
getenv
(
"VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT"
,
False
))
# noqa
ARTIFICIAL_PREEMPTION_PROB
=
0.5
ARTIFICIAL_PREEMPTION_MAX_CNT
=
500
class
PreemptionMode
(
enum
.
Enum
):
class
PreemptionMode
(
enum
.
Enum
):
"""Preemption modes.
"""Preemption modes.
...
@@ -119,6 +128,8 @@ class SchedulerOutputs:
...
@@ -119,6 +128,8 @@ class SchedulerOutputs:
ignored_seq_groups
:
List
[
SequenceGroup
]
ignored_seq_groups
:
List
[
SequenceGroup
]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
# The number of requests in the running queue
running_queue_size
:
int
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Swap in and swap out should never happen at the same time.
# Swap in and swap out should never happen at the same time.
...
@@ -201,6 +212,8 @@ class SchedulerSwappedInOutputs:
...
@@ -201,6 +212,8 @@ class SchedulerSwappedInOutputs:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
blocks_to_copy
:
Dict
[
int
,
List
[
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
num_lookahead_slots
:
int
num_lookahead_slots
:
int
# Infeasible sequence groups.
infeasible_seq_groups
:
List
[
SequenceGroup
]
@
classmethod
@
classmethod
def
create_empty
(
cls
)
->
"SchedulerSwappedInOutputs"
:
def
create_empty
(
cls
)
->
"SchedulerSwappedInOutputs"
:
...
@@ -210,6 +223,7 @@ class SchedulerSwappedInOutputs:
...
@@ -210,6 +223,7 @@ class SchedulerSwappedInOutputs:
blocks_to_swap_in
=
{},
blocks_to_swap_in
=
{},
blocks_to_copy
=
{},
blocks_to_copy
=
{},
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
infeasible_seq_groups
=
[],
)
)
...
@@ -286,6 +300,13 @@ class Scheduler:
...
@@ -286,6 +300,13 @@ class Scheduler:
# Latency of the last prompt step
# Latency of the last prompt step
self
.
last_prompt_latency
=
0.0
self
.
last_prompt_latency
=
0.0
# The following field is test-only. It is used to inject artificial
# preemption.
self
.
enable_artificial_preemption
=
ENABLE_ARTIFICIAL_PREEMPT
self
.
artificial_preempt_cnt
=
(
ARTIFICIAL_PREEMPTION_MAX_CNT
if
self
.
enable_artificial_preemption
else
0
)
@
property
@
property
def
lora_enabled
(
self
)
->
bool
:
def
lora_enabled
(
self
)
->
bool
:
return
bool
(
self
.
lora_config
)
return
bool
(
self
.
lora_config
)
...
@@ -320,7 +341,7 @@ class Scheduler:
...
@@ -320,7 +341,7 @@ class Scheduler:
for
seq_group
in
state_queue
:
for
seq_group
in
state_queue
:
if
not
request_ids
:
if
not
request_ids
:
# Using 'break' here may add two extra iterations,
# Using 'break' here may add two extra iterations,
# but is acceptable to reduce complexity
.
# but is acceptable to reduce complexity.
break
break
if
seq_group
.
request_id
in
request_ids
:
if
seq_group
.
request_id
in
request_ids
:
# Appending aborted group into pending list.
# Appending aborted group into pending list.
...
@@ -386,15 +407,13 @@ class Scheduler:
...
@@ -386,15 +407,13 @@ class Scheduler:
# groups to preempt.
# groups to preempt.
now
=
time
.
time
()
now
=
time
.
time
()
running_queue
=
policy
.
sort_by_priority
(
now
,
running_queue
)
running_queue
=
policy
.
sort_by_priority
(
now
,
running_queue
)
while
running_queue
:
while
running_queue
:
seq_group
=
running_queue
[
0
]
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
num_running_tokens
=
self
.
_get_num_new_tokens
(
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
# We can have up to 1 running prefill at any given time in running
if
num_running_tokens
==
0
:
# queue, which means we can guarantee chunk size is at least 1.
break
assert
num_running_tokens
!=
0
running_queue
.
popleft
()
running_queue
.
popleft
()
while
not
self
.
_can_append_slots
(
seq_group
):
while
not
self
.
_can_append_slots
(
seq_group
):
...
@@ -449,9 +468,6 @@ class Scheduler:
...
@@ -449,9 +468,6 @@ class Scheduler:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
if
curr_loras
is
not
None
and
seq_group
.
lora_int_id
>
0
:
curr_loras
.
add
(
seq_group
.
lora_int_id
)
curr_loras
.
add
(
seq_group
.
lora_int_id
)
# Make sure all queues are updated.
assert
len
(
running_queue
)
==
0
return
running_queue
,
SchedulerRunningOutputs
(
return
running_queue
,
SchedulerRunningOutputs
(
decode_seq_groups
=
decode_seq_groups
,
decode_seq_groups
=
decode_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
prefill_seq_groups
=
prefill_seq_groups
,
...
@@ -500,14 +516,26 @@ class Scheduler:
...
@@ -500,14 +516,26 @@ class Scheduler:
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
now
=
time
.
time
()
now
=
time
.
time
()
swapped_queue
=
policy
.
sort_by_priority
(
now
,
swapped_queue
)
swapped_queue
=
policy
.
sort_by_priority
(
now
,
swapped_queue
)
infeasible_seq_groups
:
List
[
SequenceGroup
]
=
[]
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
leftover_swapped
:
Deque
[
SequenceGroup
]
=
deque
()
while
swapped_queue
:
while
swapped_queue
:
seq_group
=
swapped_queue
[
0
]
seq_group
=
swapped_queue
[
0
]
# If the sequence group cannot be swapped in, stop.
# If the sequence group cannot be swapped in, stop.
if
not
self
.
block_manager
.
can_swap_in
(
seq_group
):
alloc_status
=
self
.
block_manager
.
can_swap_in
(
seq_group
)
if
alloc_status
==
AllocStatus
.
LATER
:
break
break
elif
alloc_status
==
AllocStatus
.
NEVER
:
logger
.
warning
(
"Failing the request %s because there's not enough kv "
"cache blocks to run the entire sequence."
,
seq_group
.
request_id
)
for
seq
in
seq_group
.
get_seqs
():
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
infeasible_seq_groups
.
append
(
seq_group
)
swapped_queue
.
popleft
()
continue
lora_int_id
=
0
lora_int_id
=
0
if
self
.
lora_enabled
:
if
self
.
lora_enabled
:
...
@@ -545,7 +573,6 @@ class Scheduler:
...
@@ -545,7 +573,6 @@ class Scheduler:
ScheduledSequenceGroup
(
seq_group
,
ScheduledSequenceGroup
(
seq_group
,
token_chunk_size
=
num_new_tokens
))
token_chunk_size
=
num_new_tokens
))
else
:
else
:
assert
num_new_tokens
==
1
decode_seq_groups
.
append
(
decode_seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
,
token_chunk_size
=
1
))
ScheduledSequenceGroup
(
seq_group
,
token_chunk_size
=
1
))
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_new_tokens
)
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_new_tokens
)
...
@@ -559,7 +586,9 @@ class Scheduler:
...
@@ -559,7 +586,9 @@ class Scheduler:
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
))
is_prefill
=
False
),
infeasible_seq_groups
=
infeasible_seq_groups
,
)
def
_schedule_prefills
(
def
_schedule_prefills
(
self
,
self
,
...
@@ -617,8 +646,9 @@ class Scheduler:
...
@@ -617,8 +646,9 @@ class Scheduler:
if
num_new_tokens
>
self
.
prompt_limit
:
if
num_new_tokens
>
self
.
prompt_limit
:
logger
.
warning
(
logger
.
warning
(
f
"Input prompt (
{
num_new_tokens
}
tokens) is too long"
"Input prompt (%d tokens) is too long"
f
" and exceeds limit of
{
self
.
prompt_limit
}
"
)
" and exceeds limit of %d"
,
num_new_tokens
,
self
.
prompt_limit
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
...
@@ -631,8 +661,9 @@ class Scheduler:
...
@@ -631,8 +661,9 @@ class Scheduler:
break
break
elif
can_allocate
==
AllocStatus
.
NEVER
:
elif
can_allocate
==
AllocStatus
.
NEVER
:
logger
.
warning
(
logger
.
warning
(
f
"Input prompt (
{
num_new_tokens
}
tokens) is too long"
"Input prompt (%d tokens) is too long"
f
" and exceeds the capacity of block_manager"
)
" and exceeds the capacity of block_manager"
,
num_new_tokens
)
for
seq
in
waiting_seqs
:
for
seq
in
waiting_seqs
:
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
seq
.
status
=
SequenceStatus
.
FINISHED_IGNORED
ignored_seq_groups
.
append
(
seq_group
)
ignored_seq_groups
.
append
(
seq_group
)
...
@@ -765,8 +796,10 @@ class Scheduler:
...
@@ -765,8 +796,10 @@ class Scheduler:
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
blocks_to_copy
=
merge_dicts
(
running_scheduled
.
blocks_to_copy
,
swapped_in
.
blocks_to_copy
),
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
+
swapped_in
.
infeasible_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
)
)
def
_schedule_chunked_prefill
(
self
):
def
_schedule_chunked_prefill
(
self
):
...
@@ -853,6 +886,7 @@ class Scheduler:
...
@@ -853,6 +886,7 @@ class Scheduler:
swapped_in
.
blocks_to_copy
),
swapped_in
.
blocks_to_copy
),
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
ignored_seq_groups
=
prefills
.
ignored_seq_groups
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
num_lookahead_slots
=
running_scheduled
.
num_lookahead_slots
,
running_queue_size
=
len
(
self
.
running
),
)
)
def
_schedule
(
self
)
->
SchedulerOutputs
:
def
_schedule
(
self
)
->
SchedulerOutputs
:
...
@@ -866,6 +900,13 @@ class Scheduler:
...
@@ -866,6 +900,13 @@ class Scheduler:
"""Determine whether or not we have enough space in the KV cache to
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
continue generation of the sequence group.
"""
"""
# True only in tests, where it triggers artificial preemption.
if
(
self
.
enable_artificial_preemption
and
random
.
uniform
(
0
,
1
)
<
ARTIFICIAL_PREEMPTION_PROB
and
self
.
artificial_preempt_cnt
>
0
):
self
.
artificial_preempt_cnt
-=
1
return
False
# Appending slots only occurs in decoding.
# Appending slots only occurs in decoding.
is_prefill
=
False
is_prefill
=
False
...
@@ -874,15 +915,6 @@ class Scheduler:
...
@@ -874,15 +915,6 @@ class Scheduler:
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
)
)
def
_can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
# Swapping in is considered decode.
is_prefill
=
False
return
self
.
block_manager
.
can_swap_in
(
seq_group
=
seq_group
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
)
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
# Schedule sequence groups.
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# This function call changes the internal states of the scheduler
...
@@ -913,6 +945,20 @@ class Scheduler:
...
@@ -913,6 +945,20 @@ class Scheduler:
self
.
block_manager
.
get_common_computed_block_ids
(
self
.
block_manager
.
get_common_computed_block_ids
(
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)))
do_sample
=
True
if
seq_group
.
is_prefill
():
seqs
=
seq_group
.
get_seqs
()
# Prefill has only 1 sequence.
assert
len
(
seqs
)
==
1
# If some prompt tokens are still uncomputed after this iteration,
# the prefill is chunked and we don't need sampling.
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previously generated
# output tokens.
if
(
token_chunk_size
+
seqs
[
0
].
data
.
get_num_computed_tokens
()
<
seqs
[
0
].
data
.
get_len
()):
do_sample
=
False
# It assumes the scheduled_seq_groups is ordered by
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
# prefill < decoding.
is_prompt
=
seq_group
.
is_prefill
()
is_prompt
=
seq_group
.
is_prefill
()
...
@@ -922,6 +968,7 @@ class Scheduler:
...
@@ -922,6 +968,7 @@ class Scheduler:
seq_data
=
seq_data
,
seq_data
=
seq_data
,
sampling_params
=
seq_group
.
sampling_params
,
sampling_params
=
seq_group
.
sampling_params
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
do_sample
=
do_sample
,
token_chunk_size
=
token_chunk_size
,
token_chunk_size
=
token_chunk_size
,
lora_request
=
seq_group
.
lora_request
,
lora_request
=
seq_group
.
lora_request
,
computed_block_nums
=
common_computed_block_nums
,
computed_block_nums
=
common_computed_block_nums
,
...
@@ -1099,11 +1146,14 @@ class Scheduler:
...
@@ -1099,11 +1146,14 @@ class Scheduler:
if `enable_chunking` is True. If a sequence group has multiple
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
phase, so chunking doesn't happen.
Returns 0 if the new token cannot be computed due to token budget.
"""
"""
num_new_tokens
=
0
num_new_tokens
=
0
seqs
=
seq_group
.
get_seqs
(
status
=
status
)
seqs
=
seq_group
.
get_seqs
(
status
=
status
)
for
seq
in
seqs
:
for
seq
in
seqs
:
num_new_tokens
+=
seq
.
get_num_new_tokens
()
num_new_tokens
+=
seq
.
get_num_new_tokens
()
assert
num_new_tokens
>
0
# Chunk if a running request cannot fit in.
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
# decode phase. Do not chunk in that case.
...
...
vllm/distributed/communication_op.py
View file @
1591c68f
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
.parallel_state
import
(
get_tensor_model_parallel_group
,
from
.parallel_state
import
(
get_cpu_world_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
is_pynccl_enabled_for_all_reduce
)
is_pynccl_enabled_for_all_reduce
)
...
@@ -33,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...
@@ -33,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
if
out
is
not
None
:
if
out
is
not
None
:
return
out
return
out
if
is_pynccl_enabled_for_all_reduce
():
if
is_pynccl_enabled_for_all_reduce
():
# TODO: support multiple parallel groups.
pynccl_utils
.
all_reduce
(
input_
)
pynccl_utils
.
all_reduce
(
input_
)
else
:
else
:
torch
.
distributed
.
all_reduce
(
input_
,
torch
.
distributed
.
all_reduce
(
input_
,
...
@@ -140,13 +140,46 @@ def broadcast_object_list(obj_list: List[Any],
...
@@ -140,13 +140,46 @@ def broadcast_object_list(obj_list: List[Any],
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"dtype"
,
"size"
])
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"dtype"
,
"size"
])
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list
=
[]
tensor_list
=
[]
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Note(youkaichao): currently this only supports broadcasting
# tensors on cuda. In the future, we can add device as a field in
# TensorMetadata to support broadcasting tensors on different
# devices.
assert
value
.
is_cuda
,
(
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
f
"support broadcasting tensors on cuda."
)
metadata_list
.
append
((
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
else
:
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
def
broadcast_tensor_dict
(
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
src
:
int
=
0
,
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
,
group
:
Optional
[
ProcessGroup
]
=
None
,
metadata_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Broadcast the input tensor dictionary."""
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
group
=
group
or
torch
.
distributed
.
group
.
WORLD
group
=
group
or
torch
.
distributed
.
group
.
WORLD
metadata_group
=
metadata_group
or
get_cpu_world_group
()
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
...
@@ -161,27 +194,20 @@ def broadcast_tensor_dict(
...
@@ -161,27 +194,20 @@ def broadcast_tensor_dict(
assert
isinstance
(
assert
isinstance
(
tensor_dict
,
tensor_dict
,
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
for
key
,
value
in
tensor_dict
.
items
():
metadata_list
,
tensor_list
=
_split_tensor_dict
(
tensor_dict
)
if
isinstance
(
value
,
torch
.
Tensor
):
# `metadata_list` lives in CPU memory.
assert
value
.
is_cuda
,
(
# `broadcast_object_list` involves serialization and deserialization,
f
"Tensor
{
key
}
:
{
value
}
is not on cuda. Currently we only "
# all happening on CPU. Therefore, we can use the CPU group.
f
"support broadcasting tensors on cuda."
)
metadata_list
.
append
(
(
key
,
TensorMetadata
(
value
.
dtype
,
value
.
size
())))
else
:
metadata_list
.
append
((
key
,
value
))
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
src
=
src
,
src
=
src
,
group
=
group
)
group
=
metadata_
group
)
async_handles
=
[]
async_handles
=
[]
for
key
,
value
in
metadata_list
:
for
tensor
in
tensor_list
:
if
isinstance
(
value
,
TensorMetadata
):
async_handles
.
append
(
tensor
=
tensor_dict
[
key
]
torch
.
distributed
.
broadcast
(
tensor
,
async_handles
.
append
(
src
=
src
,
torch
.
distributed
.
broadcast
(
tensor
,
group
=
group
,
src
=
src
,
async_op
=
True
))
group
=
group
,
async_op
=
True
))
for
async_handle
in
async_handles
:
for
async_handle
in
async_handles
:
async_handle
.
wait
()
async_handle
.
wait
()
...
@@ -189,7 +215,7 @@ def broadcast_tensor_dict(
...
@@ -189,7 +215,7 @@ def broadcast_tensor_dict(
recv_metadata_list
=
[
None
]
recv_metadata_list
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
src
=
src
,
src
=
src
,
group
=
group
)
group
=
metadata_
group
)
assert
recv_metadata_list
[
0
]
is
not
None
assert
recv_metadata_list
[
0
]
is
not
None
tensor_dict
=
{}
tensor_dict
=
{}
async_handles
=
[]
async_handles
=
[]
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
1591c68f
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
try
:
try
:
...
@@ -37,7 +37,7 @@ def init_custom_ar() -> None:
...
@@ -37,7 +37,7 @@ def init_custom_ar() -> None:
return
return
if
world_size
not
in
_SUPPORTED_WORLD_SIZES
:
if
world_size
not
in
_SUPPORTED_WORLD_SIZES
:
logger
.
warn
(
logger
.
warn
ing
(
"Custom allreduce is disabled due to an unsupported world size: "
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
,
world_size
,
" disable_custom_all_reduce=True explicitly."
,
world_size
,
...
@@ -47,22 +47,22 @@ def init_custom_ar() -> None:
...
@@ -47,22 +47,22 @@ def init_custom_ar() -> None:
# note: num dev can be larger than world_size if we're only using
# note: num dev can be larger than world_size if we're only using
# first few GPUs
# first few GPUs
if
num_dev
<
world_size
:
if
num_dev
<
world_size
:
logger
.
warn
(
logger
.
warn
ing
(
"Cannot test GPU P2P because not all GPUs are visible to the "
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set."
)
" is set."
)
return
return
# test nvlink first, this will filter out most of the cases
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# where custom allreduce is not supported
if
"
CUDA_VISIBLE_DEVICES
"
in
os
.
environ
:
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
device_ids
=
list
(
if
cuda_visible_devices
:
map
(
int
,
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
.
split
(
","
)))
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
else
:
device_ids
=
list
(
range
(
num_dev
))
device_ids
=
list
(
range
(
num_dev
))
# this checks hardware and driver support for NVLink
# this checks hardware and driver support for NVLink
full_nvlink
=
_is_full_nvlink
(
device_ids
)
full_nvlink
=
_is_full_nvlink
(
device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warn
(
logger
.
warn
ing
(
"Custom allreduce is disabled because it's not supported on more"
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
" disable_custom_all_reduce=True explicitly."
)
...
@@ -71,7 +71,7 @@ def init_custom_ar() -> None:
...
@@ -71,7 +71,7 @@ def init_custom_ar() -> None:
# this is expensive to compute at the first time
# this is expensive to compute at the first time
# then we cache the result
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warn
(
logger
.
warn
ing
(
"Custom allreduce is disabled because your platform lacks GPU P2P"
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
" disable_custom_all_reduce=True explicitly."
)
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
1591c68f
...
@@ -43,15 +43,16 @@ try:
...
@@ -43,15 +43,16 @@ try:
nccl
=
ctypes
.
CDLL
(
so_file
)
nccl
=
ctypes
.
CDLL
(
so_file
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
logger
.
error
(
f
"Failed to load NCCL library from
{
so_file
}
."
"Failed to load NCCL library from
%s
."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise, the nccl library might not exist, be corrupted "
"Otherwise, the nccl library might not exist, be corrupted "
f
"or it does not support the current platform
{
platform
.
platform
()
}
."
"or it does not support the current platform %s."
f
"One solution is to download libnccl2 version 2.18 from "
"One solution is to download libnccl2 version 2.18 from "
f
"https://developer.download.nvidia.com/compute/cuda/repos/ "
"https://developer.download.nvidia.com/compute/cuda/repos/ "
f
"and extract the libnccl.so.2 file. If you already have the "
"and extract the libnccl.so.2 file. If you already have the "
f
"library, please set the environment variable VLLM_NCCL_SO_PATH"
"library, please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path."
)
" to point to the correct nccl library path."
,
so_file
,
platform
.
platform
())
raise
e
raise
e
# === export types and functions from nccl to Python ===
# === export types and functions from nccl to Python ===
...
@@ -199,6 +200,10 @@ _c_ncclAllReduce.argtypes = [
...
@@ -199,6 +200,10 @@ _c_ncclAllReduce.argtypes = [
ncclDataType_t
,
ctypes
.
c_void_p
,
ctypes
.
c_void_p
ncclDataType_t
,
ctypes
.
c_void_p
,
ctypes
.
c_void_p
]
]
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
# it is better not to call it at all.
# equivalent to c declaration:
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy
=
nccl
.
ncclCommDestroy
_c_ncclCommDestroy
=
nccl
.
ncclCommDestroy
...
@@ -227,6 +232,7 @@ class NCCLCommunicator:
...
@@ -227,6 +232,7 @@ class NCCLCommunicator:
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"NCCLCommunicator should be attached to a non-NCCL group."
)
"NCCLCommunicator should be attached to a non-NCCL group."
)
self
.
group
=
group
self
.
group
=
group
# note: this rank is the rank in the group
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
rank
=
dist
.
get_rank
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
self
.
world_size
=
dist
.
get_world_size
(
group
)
if
self
.
rank
==
0
:
if
self
.
rank
==
0
:
...
@@ -234,7 +240,9 @@ class NCCLCommunicator:
...
@@ -234,7 +240,9 @@ class NCCLCommunicator:
else
:
else
:
self
.
unique_id
=
NcclUniqueId
()
self
.
unique_id
=
NcclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
))
dist
.
broadcast
(
tensor
,
src
=
0
,
group
=
group
)
ranks
=
dist
.
get_process_group_ranks
(
group
)
# arg `src` in `broadcast` is the global rank
dist
.
broadcast
(
tensor
,
src
=
ranks
[
0
],
group
=
group
)
byte_list
=
tensor
.
tolist
()
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
unique_id
.
internal
[
i
]
=
byte
...
@@ -250,15 +258,13 @@ class NCCLCommunicator:
...
@@ -250,15 +258,13 @@ class NCCLCommunicator:
assert
isinstance
(
device
,
torch
.
device
)
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
self
.
device
=
device
# nccl communicator and stream will use this device
# nccl communicator and stream will use this device
current_device
=
torch
.
cuda
.
current_device
()
# `torch.cuda.device` is a context manager that changes the
try
:
# current cuda device to the specified one
torch
.
cuda
.
set_
device
(
device
)
with
torch
.
cuda
.
device
(
device
)
:
NCCL_CHECK
(
NCCL_CHECK
(
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
self
.
world_size
,
self
.
unique_id
,
self
.
rank
))
self
.
unique_id
,
self
.
rank
))
self
.
stream
=
torch
.
cuda
.
Stream
()
self
.
stream
=
torch
.
cuda
.
Stream
()
finally
:
torch
.
cuda
.
set_device
(
current_device
)
def
all_reduce
(
self
,
def
all_reduce
(
self
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
...
@@ -279,11 +285,3 @@ class NCCLCommunicator:
...
@@ -279,11 +285,3 @@ class NCCLCommunicator:
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
ncclDataTypeEnum
.
from_torch
(
tensor
.
dtype
),
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
ctypes
.
c_void_p
(
stream
.
cuda_stream
)))
ctypes
.
c_void_p
(
stream
.
cuda_stream
)))
def
__del__
(
self
):
# `dist` module might have been already destroyed
if
hasattr
(
dist
,
'destroy_process_group'
):
dist
.
destroy_process_group
()
# function might have been already destroyed
if
_c_ncclCommDestroy
is
not
None
:
_c_ncclCommDestroy
(
self
.
comm
)
vllm/distributed/device_communicators/pynccl_utils.py
View file @
1591c68f
...
@@ -14,7 +14,7 @@ try:
...
@@ -14,7 +14,7 @@ try:
except
Exception
as
e
:
except
Exception
as
e
:
# in non-NVIDIA environments, we can't import the nccl module
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
# e.g. when running on machines with AMD GPUs
logger
.
info
(
f
"Failed to import NCCL library:
{
e
}
"
)
logger
.
info
(
"Failed to import NCCL library:
%s"
,
e
)
logger
.
info
(
"It is expected if you are not running on NVIDIA GPUs."
)
logger
.
info
(
"It is expected if you are not running on NVIDIA GPUs."
)
pass
pass
...
@@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
...
@@ -40,7 +40,7 @@ def set_pynccl_stream(stream: torch.cuda.Stream):
def
init_process_group
(
group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
def
init_process_group
(
group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
assert
not
is_initialized
()
assert
not
is_initialized
()
global
comm
global
comm
logger
.
info
(
f
"vLLM is using nccl==
{
ncclGetVersion
()
}
"
)
logger
.
info
(
"vLLM is using nccl==
%s"
,
ncclGetVersion
())
comm
=
NCCLCommunicator
(
group
=
group
)
comm
=
NCCLCommunicator
(
group
=
group
)
...
...
vllm/distributed/parallel_state.py
View file @
1591c68f
...
@@ -4,17 +4,18 @@
...
@@ -4,17 +4,18 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
"""Tensor and pipeline parallel groups."""
import
contextlib
import
contextlib
import
os
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Tensor model parallel group that the current rank belongs to.
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP
=
None
_TP_DEVICE_GROUP
=
None
_TP_CPU_GROUP
=
None
# Pipeline model parallel group that the current rank belongs to.
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
_PIPELINE_MODEL_PARALLEL_GROUP
=
None
...
@@ -57,8 +58,10 @@ def init_distributed_environment(
...
@@ -57,8 +58,10 @@ def init_distributed_environment(
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
backend
:
str
=
"nccl"
,
backend
:
str
=
"nccl"
,
):
):
logger
.
debug
(
f
"
{
world_size
=
}
{
rank
=
}
{
local_rank
=
}
"
logger
.
debug
(
f
"
{
distributed_init_method
=
}
{
backend
=
}
"
)
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s"
,
world_size
,
rank
,
local_rank
,
distributed_init_method
,
backend
)
if
not
torch
.
distributed
.
is_initialized
():
if
not
torch
.
distributed
.
is_initialized
():
assert
distributed_init_method
is
not
None
,
(
assert
distributed_init_method
is
not
None
,
(
"distributed_init_method must be provided when initializing "
"distributed_init_method must be provided when initializing "
...
@@ -78,7 +81,7 @@ def init_distributed_environment(
...
@@ -78,7 +81,7 @@ def init_distributed_environment(
# local_rank is not available in torch ProcessGroup,
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
# see https://github.com/pytorch/pytorch/issues/122816
if
local_rank
==
-
1
and
distributed_init_method
==
"env://"
:
if
local_rank
==
-
1
and
distributed_init_method
==
"env://"
:
local_rank
=
int
(
os
.
environ
[
'
LOCAL_RANK
'
])
local_rank
=
envs
.
LOCAL_RANK
global
_LOCAL_RANK
global
_LOCAL_RANK
_LOCAL_RANK
=
local_rank
_LOCAL_RANK
=
local_rank
...
@@ -130,15 +133,17 @@ def initialize_model_parallel(
...
@@ -130,15 +133,17 @@ def initialize_model_parallel(
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
# Build the tensor model-parallel groups.
# Build the tensor model-parallel groups.
global
_T
ENSOR_MODEL_PARALLEL
_GROUP
global
_T
P_DEVICE_GROUP
,
_TP_CPU
_GROUP
assert
_T
ENSOR_MODEL_PARALLEL
_GROUP
is
None
,
(
assert
_T
P_DEVICE
_GROUP
is
None
,
(
"tensor model parallel group is already initialized"
)
"tensor model parallel group is already initialized"
)
for
i
in
range
(
num_tensor_model_parallel_groups
):
for
i
in
range
(
num_tensor_model_parallel_groups
):
ranks
=
range
(
i
*
tensor_model_parallel_size
,
ranks
=
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
)
(
i
+
1
)
*
tensor_model_parallel_size
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
backend
)
cpu_group
=
torch
.
distributed
.
new_group
(
ranks
,
backend
=
"gloo"
)
if
rank
in
ranks
:
if
rank
in
ranks
:
_TENSOR_MODEL_PARALLEL_GROUP
=
group
_TP_DEVICE_GROUP
=
group
_TP_CPU_GROUP
=
cpu_group
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
global
_PIPELINE_MODEL_PARALLEL_GROUP
global
_PIPELINE_MODEL_PARALLEL_GROUP
...
@@ -183,7 +188,7 @@ def ensure_model_parallel_initialized(
...
@@ -183,7 +188,7 @@ def ensure_model_parallel_initialized(
def
model_parallel_is_initialized
():
def
model_parallel_is_initialized
():
"""Check if tensor and pipeline parallel groups are initialized."""
"""Check if tensor and pipeline parallel groups are initialized."""
return
(
_T
ENSOR_MODEL_PARALLEL
_GROUP
is
not
None
return
(
_T
P_DEVICE
_GROUP
is
not
None
and
_PIPELINE_MODEL_PARALLEL_GROUP
is
not
None
)
and
_PIPELINE_MODEL_PARALLEL_GROUP
is
not
None
)
...
@@ -195,9 +200,16 @@ def get_cpu_world_group():
...
@@ -195,9 +200,16 @@ def get_cpu_world_group():
def
get_tensor_model_parallel_group
():
def
get_tensor_model_parallel_group
():
"""Get the tensor model parallel group the caller rank belongs to."""
"""Get the tensor model parallel group the caller rank belongs to."""
assert
_T
ENSOR_MODEL_PARALLEL
_GROUP
is
not
None
,
(
assert
_T
P_DEVICE
_GROUP
is
not
None
,
(
"tensor model parallel group is not initialized"
)
"tensor model parallel group is not initialized"
)
return
_TENSOR_MODEL_PARALLEL_GROUP
return
_TP_DEVICE_GROUP
def get_tensor_model_parallel_cpu_group():
    """Return the CPU-side tensor model parallel group of the calling rank.

    Fails with an assertion error if the group has not been set up yet.
    """
    cpu_group = _TP_CPU_GROUP
    assert cpu_group is not None, (
        "tensor model parallel cpu group is not initialized")
    return cpu_group
def
get_pipeline_model_parallel_group
():
def
get_pipeline_model_parallel_group
():
...
@@ -275,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
...
@@ -275,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
def
destroy_model_parallel
():
def
destroy_model_parallel
():
"""Set the groups to none and destroy them."""
"""Set the groups to none and destroy them."""
global
_TENSOR_MODEL_PARALLEL_GROUP
global
_TP_DEVICE_GROUP
if
_TENSOR_MODEL_PARALLEL_GROUP
:
if
_TP_DEVICE_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_TENSOR_MODEL_PARALLEL_GROUP
)
torch
.
distributed
.
destroy_process_group
(
_TP_DEVICE_GROUP
)
_TENSOR_MODEL_PARALLEL_GROUP
=
None
_TP_DEVICE_GROUP
=
None
global
_TP_CPU_GROUP
if
_TP_CPU_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_TP_CPU_GROUP
)
_TP_CPU_GROUP
=
None
global
_PIPELINE_MODEL_PARALLEL_GROUP
global
_PIPELINE_MODEL_PARALLEL_GROUP
if
_PIPELINE_MODEL_PARALLEL_GROUP
:
if
_PIPELINE_MODEL_PARALLEL_GROUP
:
torch
.
distributed
.
destroy_process_group
(
_PIPELINE_MODEL_PARALLEL_GROUP
)
torch
.
distributed
.
destroy_process_group
(
_PIPELINE_MODEL_PARALLEL_GROUP
)
...
...
vllm/distributed/utils.py
View file @
1591c68f
...
@@ -9,6 +9,7 @@ from typing import Dict, Optional, Sequence
...
@@ -9,6 +9,7 @@ from typing import Dict, Optional, Sequence
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
.parallel_state
import
get_cpu_world_group
,
get_local_rank
...
@@ -102,17 +103,19 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
...
@@ -102,17 +103,19 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
is_distributed
=
dist
.
is_initialized
()
is_distributed
=
dist
.
is_initialized
()
num_dev
=
torch
.
cuda
.
device_count
()
num_dev
=
torch
.
cuda
.
device_count
()
cuda_visible_devices
=
os
.
environ
.
get
(
"
CUDA_VISIBLE_DEVICES
"
,
None
)
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
is
None
:
if
cuda_visible_devices
is
None
:
cuda_visible_devices
=
","
.
join
(
str
(
i
)
for
i
in
range
(
num_dev
))
cuda_visible_devices
=
","
.
join
(
str
(
i
)
for
i
in
range
(
num_dev
))
VLLM_CONFIG_ROOT
=
envs
.
VLLM_CONFIG_ROOT
path
=
os
.
path
.
expanduser
(
path
=
os
.
path
.
expanduser
(
f
"~/.config/vllm/gpu_p2p_access_cache_for_
{
cuda_visible_devices
}
.json"
)
f
"
{
VLLM_CONFIG_ROOT
}
/vllm/gpu_p2p_access_cache_for_
{
cuda_visible_devices
}
.json"
)
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
if
(
not
is_distributed
or
get_local_rank
()
==
0
)
\
if
(
not
is_distributed
or
get_local_rank
()
==
0
)
\
and
(
not
os
.
path
.
exists
(
path
)):
and
(
not
os
.
path
.
exists
(
path
)):
# only the local master process (with local_rank == 0) can
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
# enter this block to calculate the cache
logger
.
info
(
f
"generating GPU P2P access cache for in
{
path
}
"
)
logger
.
info
(
"generating GPU P2P access cache for in
%s"
,
path
)
cache
=
{}
cache
=
{}
for
_i
in
range
(
num_dev
):
for
_i
in
range
(
num_dev
):
for
_j
in
range
(
num_dev
):
for
_j
in
range
(
num_dev
):
...
@@ -126,7 +129,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
...
@@ -126,7 +129,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
if
is_distributed
:
if
is_distributed
:
cpu_world_group
=
get_cpu_world_group
()
cpu_world_group
=
get_cpu_world_group
()
dist
.
barrier
(
cpu_world_group
)
dist
.
barrier
(
cpu_world_group
)
logger
.
info
(
f
"reading GPU P2P access cache from
{
path
}
"
)
logger
.
info
(
"reading GPU P2P access cache from
%s"
,
path
)
with
open
(
path
,
"r"
)
as
f
:
with
open
(
path
,
"r"
)
as
f
:
cache
=
json
.
load
(
f
)
cache
=
json
.
load
(
f
)
_gpu_p2p_access_cache
=
cache
_gpu_p2p_access_cache
=
cache
...
...
vllm/engine/arg_utils.py
View file @
1591c68f
import
argparse
import
argparse
import
dataclasses
import
dataclasses
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
List
,
Optional
,
Union
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
...
@@ -11,10 +11,17 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...
@@ -11,10 +11,17 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from
vllm.utils
import
str_to_int_tuple
from
vllm.utils
import
str_to_int_tuple
def nullable_str(val: str):
    """Argparse ``type`` helper mapping "no value" spellings to ``None``.

    A falsy input (e.g. the empty string) or the literal text ``"None"``
    yields ``None``; any other value is returned unchanged.
    """
    if val and val != "None":
        return val
    return None
@
dataclass
@
dataclass
class
EngineArgs
:
class
EngineArgs
:
"""Arguments for vLLM engine."""
"""Arguments for vLLM engine."""
model
:
str
model
:
str
served_model_name
:
Optional
[
Union
[
List
[
str
]]]
=
None
tokenizer
:
Optional
[
str
]
=
None
tokenizer
:
Optional
[
str
]
=
None
skip_tokenizer_init
:
bool
=
False
skip_tokenizer_init
:
bool
=
False
tokenizer_mode
:
str
=
'auto'
tokenizer_mode
:
str
=
'auto'
...
@@ -44,7 +51,8 @@ class EngineArgs:
...
@@ -44,7 +51,8 @@ class EngineArgs:
tokenizer_revision
:
Optional
[
str
]
=
None
tokenizer_revision
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
quantization
:
Optional
[
str
]
=
None
enforce_eager
:
bool
=
False
enforce_eager
:
bool
=
False
max_context_len_to_capture
:
int
=
8192
max_context_len_to_capture
:
Optional
[
int
]
=
None
max_seq_len_to_capture
:
int
=
8192
disable_custom_all_reduce
:
bool
=
False
disable_custom_all_reduce
:
bool
=
False
tokenizer_pool_size
:
int
=
0
tokenizer_pool_size
:
int
=
0
tokenizer_pool_type
:
str
=
"ray"
tokenizer_pool_type
:
str
=
"ray"
...
@@ -52,6 +60,7 @@ class EngineArgs:
...
@@ -52,6 +60,7 @@ class EngineArgs:
enable_lora
:
bool
=
False
enable_lora
:
bool
=
False
max_loras
:
int
=
1
max_loras
:
int
=
1
max_lora_rank
:
int
=
16
max_lora_rank
:
int
=
16
fully_sharded_loras
:
bool
=
False
lora_extra_vocab_size
:
int
=
256
lora_extra_vocab_size
:
int
=
256
lora_dtype
=
'auto'
lora_dtype
=
'auto'
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
...
@@ -74,6 +83,8 @@ class EngineArgs:
...
@@ -74,6 +83,8 @@ class EngineArgs:
speculative_model
:
Optional
[
str
]
=
None
speculative_model
:
Optional
[
str
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
num_speculative_tokens
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
speculative_max_model_len
:
Optional
[
int
]
=
None
ngram_prompt_lookup_max
:
Optional
[
int
]
=
None
ngram_prompt_lookup_min
:
Optional
[
int
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer
is
None
:
if
self
.
tokenizer
is
None
:
...
@@ -92,7 +103,7 @@ class EngineArgs:
...
@@ -92,7 +103,7 @@ class EngineArgs:
help
=
'Name or path of the huggingface model to use.'
)
help
=
'Name or path of the huggingface model to use.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--tokenizer'
,
'--tokenizer'
,
type
=
str
,
type
=
nullable_
str
,
default
=
EngineArgs
.
tokenizer
,
default
=
EngineArgs
.
tokenizer
,
help
=
'Name or path of the huggingface tokenizer to use.'
)
help
=
'Name or path of the huggingface tokenizer to use.'
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -101,21 +112,21 @@ class EngineArgs:
...
@@ -101,21 +112,21 @@ class EngineArgs:
help
=
'Skip initialization of tokenizer and detokenizer'
)
help
=
'Skip initialization of tokenizer and detokenizer'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--revision'
,
'--revision'
,
type
=
str
,
type
=
nullable_
str
,
default
=
None
,
default
=
None
,
help
=
'The specific model version to use. It can be a branch '
help
=
'The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.'
)
'the default version.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--code-revision'
,
'--code-revision'
,
type
=
str
,
type
=
nullable_
str
,
default
=
None
,
default
=
None
,
help
=
'The specific revision to use for the model code on '
help
=
'The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.'
)
'commit id. If unspecified, will use the default version.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--tokenizer-revision'
,
'--tokenizer-revision'
,
type
=
str
,
type
=
nullable_
str
,
default
=
None
,
default
=
None
,
help
=
'The specific tokenizer version to use. It can be a branch '
help
=
'The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'name, a tag name, or a commit id. If unspecified, will use '
...
@@ -132,7 +143,7 @@ class EngineArgs:
...
@@ -132,7 +143,7 @@ class EngineArgs:
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Trust remote code from huggingface.'
)
help
=
'Trust remote code from huggingface.'
)
parser
.
add_argument
(
'--download-dir'
,
parser
.
add_argument
(
'--download-dir'
,
type
=
str
,
type
=
nullable_
str
,
default
=
EngineArgs
.
download_dir
,
default
=
EngineArgs
.
download_dir
,
help
=
'Directory to download and load the weights, '
help
=
'Directory to download and load the weights, '
'default to the default cache dir of '
'default to the default cache dir of '
...
@@ -183,7 +194,7 @@ class EngineArgs:
...
@@ -183,7 +194,7 @@ class EngineArgs:
'supported for common inference criteria.'
)
'supported for common inference criteria.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--quantization-param-path'
,
'--quantization-param-path'
,
type
=
str
,
type
=
nullable_
str
,
default
=
None
,
default
=
None
,
help
=
'Path to the JSON file containing the KV cache '
help
=
'Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when '
'scaling factors. This should generally be supplied, when '
...
@@ -300,7 +311,7 @@ class EngineArgs:
...
@@ -300,7 +311,7 @@ class EngineArgs:
# Quantization settings.
# Quantization settings.
parser
.
add_argument
(
'--quantization'
,
parser
.
add_argument
(
'--quantization'
,
'-q'
,
'-q'
,
type
=
str
,
type
=
nullable_
str
,
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
EngineArgs
.
quantization
,
default
=
EngineArgs
.
quantization
,
help
=
'Method used to quantize the weights. If '
help
=
'Method used to quantize the weights. If '
...
@@ -319,6 +330,14 @@ class EngineArgs:
...
@@ -319,6 +330,14 @@ class EngineArgs:
default
=
EngineArgs
.
max_context_len_to_capture
,
default
=
EngineArgs
.
max_context_len_to_capture
,
help
=
'Maximum context length covered by CUDA '
help
=
'Maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'(DEPRECATED. Use --max-seq_len-to-capture instead'
')'
)
parser
.
add_argument
(
'--max-seq_len-to-capture'
,
type
=
int
,
default
=
EngineArgs
.
max_seq_len_to_capture
,
help
=
'Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.'
)
'larger than this, we fall back to eager mode.'
)
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
parser
.
add_argument
(
'--disable-custom-all-reduce'
,
action
=
'store_true'
,
action
=
'store_true'
,
...
@@ -337,7 +356,7 @@ class EngineArgs:
...
@@ -337,7 +356,7 @@ class EngineArgs:
'asynchronous tokenization. Ignored '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.'
)
'if tokenizer_pool_size is 0.'
)
parser
.
add_argument
(
'--tokenizer-pool-extra-config'
,
parser
.
add_argument
(
'--tokenizer-pool-extra-config'
,
type
=
str
,
type
=
nullable_
str
,
default
=
EngineArgs
.
tokenizer_pool_extra_config
,
default
=
EngineArgs
.
tokenizer_pool_extra_config
,
help
=
'Extra config for tokenizer pool. '
help
=
'Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'This should be a JSON string that will be '
...
@@ -376,6 +395,14 @@ class EngineArgs:
...
@@ -376,6 +395,14 @@ class EngineArgs:
help
=
(
'Maximum number of LoRAs to store in CPU memory. '
help
=
(
'Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'
))
'Defaults to max_num_seqs.'
))
parser
.
add_argument
(
'--fully-sharded-loras'
,
action
=
'store_true'
,
help
=
(
'By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'
))
parser
.
add_argument
(
"--device"
,
parser
.
add_argument
(
"--device"
,
type
=
str
,
type
=
str
,
default
=
EngineArgs
.
device
,
default
=
EngineArgs
.
device
,
...
@@ -384,7 +411,7 @@ class EngineArgs:
...
@@ -384,7 +411,7 @@ class EngineArgs:
# Related to Vision-language models such as llava
# Related to Vision-language models such as llava
parser
.
add_argument
(
parser
.
add_argument
(
'--image-input-type'
,
'--image-input-type'
,
type
=
str
,
type
=
nullable_
str
,
default
=
None
,
default
=
None
,
choices
=
[
choices
=
[
t
.
name
.
lower
()
for
t
in
VisionLanguageConfig
.
ImageInputType
t
.
name
.
lower
()
for
t
in
VisionLanguageConfig
.
ImageInputType
...
@@ -397,7 +424,7 @@ class EngineArgs:
...
@@ -397,7 +424,7 @@ class EngineArgs:
help
=
(
'Input id for image token.'
))
help
=
(
'Input id for image token.'
))
parser
.
add_argument
(
parser
.
add_argument
(
'--image-input-shape'
,
'--image-input-shape'
,
type
=
str
,
type
=
nullable_
str
,
default
=
None
,
default
=
None
,
help
=
(
'The biggest image input shape (worst for memory footprint) '
help
=
(
'The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM
\'
s profile_run.'
))
'given an input type. Only used for vLLM
\'
s profile_run.'
))
...
@@ -420,7 +447,7 @@ class EngineArgs:
...
@@ -420,7 +447,7 @@ class EngineArgs:
parser
.
add_argument
(
parser
.
add_argument
(
'--speculative-model'
,
'--speculative-model'
,
type
=
str
,
type
=
nullable_
str
,
default
=
EngineArgs
.
speculative_model
,
default
=
EngineArgs
.
speculative_model
,
help
=
help
=
'The name of the draft model to be used in speculative decoding.'
)
'The name of the draft model to be used in speculative decoding.'
)
...
@@ -434,14 +461,28 @@ class EngineArgs:
...
@@ -434,14 +461,28 @@ class EngineArgs:
parser
.
add_argument
(
parser
.
add_argument
(
'--speculative-max-model-len'
,
'--speculative-max-model-len'
,
type
=
str
,
type
=
int
,
default
=
EngineArgs
.
speculative_max_model_len
,
default
=
EngineArgs
.
speculative_max_model_len
,
help
=
'The maximum sequence length supported by the '
help
=
'The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'draft model. Sequences over this length will skip '
'speculation.'
)
'speculation.'
)
parser
.
add_argument
(
'--ngram-prompt-lookup-max'
,
type
=
int
,
default
=
EngineArgs
.
ngram_prompt_lookup_max
,
help
=
'Max size of window for ngram prompt lookup in speculative '
'decoding.'
)
parser
.
add_argument
(
'--ngram-prompt-lookup-min'
,
type
=
int
,
default
=
EngineArgs
.
ngram_prompt_lookup_min
,
help
=
'Min size of window for ngram prompt lookup in speculative '
'decoding.'
)
parser
.
add_argument
(
'--model-loader-extra-config'
,
parser
.
add_argument
(
'--model-loader-extra-config'
,
type
=
str
,
type
=
nullable_
str
,
default
=
EngineArgs
.
model_loader_extra_config
,
default
=
EngineArgs
.
model_loader_extra_config
,
help
=
'Extra config for model loader. '
help
=
'Extra config for model loader. '
'This will be passed to the model loader '
'This will be passed to the model loader '
...
@@ -449,6 +490,21 @@ class EngineArgs:
...
@@ -449,6 +490,21 @@ class EngineArgs:
'This should be a JSON string that will be '
'This should be a JSON string that will be '
'parsed into a dictionary.'
)
'parsed into a dictionary.'
)
parser
.
add_argument
(
"--served-model-name"
,
nargs
=
"+"
,
type
=
str
,
default
=
None
,
help
=
"The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument. Noted that this name(s)"
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one."
)
return
parser
return
parser
@
classmethod
@
classmethod
...
@@ -467,7 +523,8 @@ class EngineArgs:
...
@@ -467,7 +523,8 @@ class EngineArgs:
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
code_revision
,
self
.
tokenizer_revision
,
self
.
max_model_len
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
quantization
,
self
.
quantization_param_path
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
enforce_eager
,
self
.
max_context_len_to_capture
,
self
.
max_logprobs
,
self
.
skip_tokenizer_init
)
self
.
max_seq_len_to_capture
,
self
.
max_logprobs
,
self
.
skip_tokenizer_init
,
self
.
served_model_name
)
cache_config
=
CacheConfig
(
self
.
block_size
,
cache_config
=
CacheConfig
(
self
.
block_size
,
self
.
gpu_memory_utilization
,
self
.
gpu_memory_utilization
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
self
.
swap_space
,
self
.
kv_cache_dtype
,
...
@@ -493,6 +550,8 @@ class EngineArgs:
...
@@ -493,6 +550,8 @@ class EngineArgs:
speculative_max_model_len
=
self
.
speculative_max_model_len
,
speculative_max_model_len
=
self
.
speculative_max_model_len
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
use_v2_block_manager
=
self
.
use_v2_block_manager
,
ngram_prompt_lookup_max
=
self
.
ngram_prompt_lookup_max
,
ngram_prompt_lookup_min
=
self
.
ngram_prompt_lookup_min
,
)
)
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
...
@@ -509,6 +568,7 @@ class EngineArgs:
...
@@ -509,6 +568,7 @@ class EngineArgs:
lora_config
=
LoRAConfig
(
lora_config
=
LoRAConfig
(
max_lora_rank
=
self
.
max_lora_rank
,
max_lora_rank
=
self
.
max_lora_rank
,
max_loras
=
self
.
max_loras
,
max_loras
=
self
.
max_loras
,
fully_sharded_loras
=
self
.
fully_sharded_loras
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
lora_extra_vocab_size
=
self
.
lora_extra_vocab_size
,
lora_dtype
=
self
.
lora_dtype
,
lora_dtype
=
self
.
lora_dtype
,
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
max_cpu_loras
=
self
.
max_cpu_loras
if
self
.
max_cpu_loras
...
...
vllm/engine/async_llm_engine.py
View file @
1591c68f
import
asyncio
import
asyncio
import
os
import
time
import
time
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
AsyncIterator
,
Callable
,
Dict
,
Iterable
,
List
,
from
typing
import
(
Any
,
AsyncIterator
,
Callable
,
Dict
,
Iterable
,
List
,
...
@@ -7,20 +6,21 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
...
@@ -7,20 +6,21 @@ from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
ModelConfig
import
vllm.envs
as
envs
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.e
ngine
.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.e
xecutor
.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
MultiModalData
from
vllm.sequence
import
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
int
(
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
os
.
environ
.
get
(
"VLLM_ENGINE_ITERATION_TIMEOUT_S"
,
"60"
))
class
AsyncEngineDeadError
(
RuntimeError
):
class
AsyncEngineDeadError
(
RuntimeError
):
...
@@ -117,7 +117,7 @@ class RequestTracker:
...
@@ -117,7 +117,7 @@ class RequestTracker:
self
.
_request_streams
[
request_id
].
put
(
request_output
)
self
.
_request_streams
[
request_id
].
put
(
request_output
)
if
request_output
.
finished
:
if
request_output
.
finished
:
if
verbose
:
if
verbose
:
logger
.
info
(
f
"Finished request
{
request_id
}
."
)
logger
.
info
(
"Finished request
%s."
,
request_id
)
self
.
abort_request
(
request_id
)
self
.
abort_request
(
request_id
)
def
process_exception
(
self
,
def
process_exception
(
self
,
...
@@ -128,7 +128,7 @@ class RequestTracker:
...
@@ -128,7 +128,7 @@ class RequestTracker:
"""Propagate an exception from the engine."""
"""Propagate an exception from the engine."""
self
.
_request_streams
[
request_id
].
put
(
exception
)
self
.
_request_streams
[
request_id
].
put
(
exception
)
if
verbose
:
if
verbose
:
logger
.
info
(
f
"Finished request
{
request_id
}
."
)
logger
.
info
(
"Finished request
%s."
,
request_id
)
self
.
abort_request
(
request_id
)
self
.
abort_request
(
request_id
)
def
add_request
(
self
,
request_id
:
str
,
def
add_request
(
self
,
request_id
:
str
,
...
@@ -151,7 +151,7 @@ class RequestTracker:
...
@@ -151,7 +151,7 @@ class RequestTracker:
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
"""Abort a request during next background loop iteration."""
"""Abort a request during next background loop iteration."""
if
verbose
:
if
verbose
:
logger
.
info
(
f
"Aborted request
{
request_id
}
."
)
logger
.
info
(
"Aborted request
%s."
,
request_id
)
self
.
_finished_requests
.
put_nowait
(
request_id
)
self
.
_finished_requests
.
put_nowait
(
request_id
)
...
@@ -210,20 +210,25 @@ class _AsyncLLMEngine(LLMEngine):
...
@@ -210,20 +210,25 @@ class _AsyncLLMEngine(LLMEngine):
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
# Execute the model.
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
)
output
=
await
self
.
model_executor
.
execute_model_async
(
output
=
await
self
.
model_executor
.
execute_model_async
(
seq_group_metadata_list
,
scheduler_outputs
.
blocks_to_swap_in
,
execute_model_req
)
scheduler_outputs
.
blocks_to_swap_out
,
scheduler_outputs
.
blocks_to_copy
)
else
:
else
:
output
=
[]
output
=
[]
request_outputs
=
self
.
_process_model_outputs
(
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
)
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
# Log stats.
# Log stats.
if
self
.
log_stats
:
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
))
return
request_outputs
return
request_outputs
...
@@ -521,11 +526,11 @@ class AsyncLLMEngine:
...
@@ -521,11 +526,11 @@ class AsyncLLMEngine:
if
shortened_token_ids
is
not
None
:
if
shortened_token_ids
is
not
None
:
shortened_token_ids
=
shortened_token_ids
[:
self
.
shortened_token_ids
=
shortened_token_ids
[:
self
.
max_log_len
]
max_log_len
]
logger
.
info
(
f
"Received request
{
request_id
}
: "
logger
.
info
(
f
"prompt:
{
shortened_
prompt
!
r
}
, "
"Received request %s:
prompt
: %r
, "
f
"sampling_params:
{
sampling_params
}
, "
"sampling_params:
%s, prompt_token_ids: %s
, "
f
"prompt_token_ids:
{
shortened_token_ids
}
, "
"lora_request: %s."
,
request_id
,
shortened_prompt
,
f
"lora_request:
{
lora_request
}
."
)
sampling_params
,
shortened_token_ids
,
lora_request
)
if
not
self
.
is_running
:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
if
self
.
start_engine_loop
:
...
@@ -697,9 +702,21 @@ class AsyncLLMEngine:
...
@@ -697,9 +702,21 @@ class AsyncLLMEngine:
else
:
else
:
return
self
.
engine
.
get_model_config
()
return
self
.
engine
.
get_model_config
()
async
def
do_log_stats
(
self
)
->
None
:
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Get the decoding configuration of the vLLM engine."""
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_decoding_config
.
remote
(
# type: ignore
)
else
:
return
self
.
engine
.
get_decoding_config
()
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
if
self
.
engine_use_ray
:
if
self
.
engine_use_ray
:
await
self
.
engine
.
do_log_stats
.
remote
()
# type: ignore
await
self
.
engine
.
do_log_stats
.
remote
(
# type: ignore
scheduler_outputs
,
model_output
)
else
:
else
:
self
.
engine
.
do_log_stats
()
self
.
engine
.
do_log_stats
()
...
@@ -717,4 +734,4 @@ class AsyncLLMEngine:
...
@@ -717,4 +734,4 @@ class AsyncLLMEngine:
raise
RuntimeError
(
"Engine is dead."
)
from
e
raise
RuntimeError
(
"Engine is dead."
)
from
e
else
:
else
:
await
self
.
engine
.
check_health_async
()
await
self
.
engine
.
check_health_async
()
logger
.
debug
(
f
"Health check took
{
time
.
perf_counter
()
-
t
}
s"
)
logger
.
debug
(
"Health check took
%fs"
,
time
.
perf_counter
()
-
t
)
vllm/engine/llm_engine.py
View file @
1591c68f
...
@@ -8,21 +8,23 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
...
@@ -8,21 +8,23 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
SchedulerOutputs
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.metrics
import
StatLogger
,
Stats
from
vllm.engine.output_processor.interfaces
import
(
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.engine.ray_utils
import
initialize_ray_cluster
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
Sequence
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
MultiModalData
,
SamplerOutput
,
SequenceGroup
,
SequenceStage
)
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
get_tokenizer_group
)
get_tokenizer_group
)
...
@@ -96,29 +98,39 @@ class LLMEngine:
...
@@ -96,29 +98,39 @@ class LLMEngine:
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
)
->
None
:
)
->
None
:
logger
.
info
(
logger
.
info
(
f
"Initializing an LLM engine (v
{
vllm
.
__version__
}
) with config: "
"Initializing an LLM engine (v%s) with config: "
f
"model=
{
model_config
.
model
!
r
}
, "
"model=%r, speculative_config=%r, tokenizer=%r, "
f
"speculative_config=
{
speculative_config
!
r
}
, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
f
"tokenizer=
{
model_config
.
tokenizer
!
r
}
, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
f
"skip_tokenizer_init=
{
model_config
.
skip_tokenizer_init
}
, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
f
"tokenizer_mode=
{
model_config
.
tokenizer_mode
}
, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
f
"revision=
{
model_config
.
revision
}
, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
f
"tokenizer_revision=
{
model_config
.
tokenizer_revision
}
, "
"quantization_param_path=%s, device_config=%s, "
f
"trust_remote_code=
{
model_config
.
trust_remote_code
}
, "
"decoding_config=%r, seed=%d, served_model_name=%s)"
,
f
"dtype=
{
model_config
.
dtype
}
, "
vllm
.
__version__
,
f
"max_seq_len=
{
model_config
.
max_model_len
}
, "
model_config
.
model
,
f
"download_dir=
{
load_config
.
download_dir
!
r
}
, "
speculative_config
,
f
"load_format=
{
load_config
.
load_format
}
, "
model_config
.
tokenizer
,
f
"tensor_parallel_size=
{
parallel_config
.
tensor_parallel_size
}
, "
model_config
.
skip_tokenizer_init
,
f
"disable_custom_all_reduce="
model_config
.
tokenizer_mode
,
f
"
{
parallel_config
.
disable_custom_all_reduce
}
, "
model_config
.
revision
,
f
"quantization=
{
model_config
.
quantization
}
, "
model_config
.
tokenizer_revision
,
f
"enforce_eager=
{
model_config
.
enforce_eager
}
, "
model_config
.
trust_remote_code
,
f
"kv_cache_dtype=
{
cache_config
.
cache_dtype
}
, "
model_config
.
dtype
,
f
"quantization_param_path=
{
model_config
.
quantization_param_path
}
, "
model_config
.
max_model_len
,
f
"device_config=
{
device_config
.
device
}
, "
load_config
.
download_dir
,
f
"decoding_config=
{
decoding_config
!
r
}
, "
load_config
.
load_format
,
f
"seed=
{
model_config
.
seed
}
)"
)
parallel_config
.
tensor_parallel_size
,
parallel_config
.
disable_custom_all_reduce
,
model_config
.
quantization
,
model_config
.
enforce_eager
,
cache_config
.
cache_dtype
,
model_config
.
quantization_param_path
,
device_config
.
device
,
decoding_config
,
model_config
.
seed
,
model_config
.
served_model_name
,
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
self
.
model_config
=
model_config
self
.
model_config
=
model_config
...
@@ -208,7 +220,8 @@ class LLMEngine:
...
@@ -208,7 +220,8 @@ class LLMEngine:
if
self
.
log_stats
:
if
self
.
log_stats
:
self
.
stat_logger
=
StatLogger
(
self
.
stat_logger
=
StatLogger
(
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
local_interval
=
_LOCAL_LOGGING_INTERVAL_SEC
,
labels
=
dict
(
model_name
=
model_config
.
model
))
labels
=
dict
(
model_name
=
model_config
.
served_model_name
),
max_model_len
=
self
.
model_config
.
max_model_len
)
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
self
.
stat_logger
.
info
(
"cache_config"
,
self
.
cache_config
)
# Create sequence output processor, e.g. for beam search or
# Create sequence output processor, e.g. for beam search or
...
@@ -237,8 +250,10 @@ class LLMEngine:
...
@@ -237,8 +250,10 @@ class LLMEngine:
if
self
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
if
self
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
num_gpu_blocks_override
=
self
.
cache_config
.
num_gpu_blocks_override
num_gpu_blocks_override
=
self
.
cache_config
.
num_gpu_blocks_override
logger
.
info
(
f
"Overriding
{
num_gpu_blocks
=
}
with "
logger
.
info
(
f
"
{
num_gpu_blocks_override
=
}
"
)
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d"
,
num_gpu_blocks
,
num_gpu_blocks_override
)
num_gpu_blocks
=
num_gpu_blocks_override
num_gpu_blocks
=
num_gpu_blocks_override
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
...
@@ -287,6 +302,12 @@ class LLMEngine:
...
@@ -287,6 +302,12 @@ class LLMEngine:
# the closure used to initialize Ray worker actors
# the closure used to initialize Ray worker actors
raise
RuntimeError
(
"LLMEngine should not be pickled!"
)
raise
RuntimeError
(
"LLMEngine should not be pickled!"
)
def
__del__
(
self
):
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if
model_executor
:
=
getattr
(
self
,
"model_executor"
,
None
):
model_executor
.
shutdown
()
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
def
get_tokenizer
(
self
)
->
"PreTrainedTokenizer"
:
return
self
.
tokenizer
.
get_lora_tokenizer
(
None
)
return
self
.
tokenizer
.
get_lora_tokenizer
(
None
)
...
@@ -414,9 +435,10 @@ class LLMEngine:
...
@@ -414,9 +435,10 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler,
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
# this doesn't deep-copy LogitsProcessor objects
sampling_params
=
sampling_params
.
clone
()
sampling_params
=
sampling_params
.
clone
()
#
inject
the eos token id into the sampling_params to support min_tokens
#
Add
the eos token id into the sampling_params to support min_tokens
# processing
# processing
sampling_params
.
eos_token_id
=
seq
.
eos_token_id
if
seq
.
eos_token_id
is
not
None
:
sampling_params
.
all_stop_token_ids
.
add
(
seq
.
eos_token_id
)
sampling_params
.
update_from_generation_config
(
sampling_params
.
update_from_generation_config
(
self
.
generation_config_fields
)
self
.
generation_config_fields
)
...
@@ -450,6 +472,10 @@ class LLMEngine:
...
@@ -450,6 +472,10 @@ class LLMEngine:
"""Gets the model configuration."""
"""Gets the model configuration."""
return
self
.
model_config
return
self
.
model_config
def
get_decoding_config
(
self
)
->
DecodingConfig
:
"""Gets the decoding configuration."""
return
self
.
decoding_config
def
get_num_unfinished_requests
(
self
)
->
int
:
def
get_num_unfinished_requests
(
self
)
->
int
:
"""Gets the number of unfinished requests."""
"""Gets the number of unfinished requests."""
return
self
.
scheduler
.
get_num_unfinished_seq_groups
()
return
self
.
scheduler
.
get_num_unfinished_seq_groups
()
...
@@ -459,9 +485,12 @@ class LLMEngine:
...
@@ -459,9 +485,12 @@ class LLMEngine:
return
self
.
scheduler
.
has_unfinished_seqs
()
return
self
.
scheduler
.
has_unfinished_seqs
()
def
_process_model_outputs
(
def
_process_model_outputs
(
self
,
output
:
List
[
SamplerOutput
],
self
,
scheduled_seq_groups
:
List
[
SequenceGroup
],
output
:
List
[
SamplerOutput
],
ignored_seq_groups
:
List
[
SequenceGroup
])
->
List
[
RequestOutput
]:
scheduled_seq_groups
:
List
[
ScheduledSequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
List
[
RequestOutput
]:
"""Apply the model output to the sequences in the scheduled seq groups.
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
Returns RequestOutputs that can be returned to the client.
...
@@ -475,17 +504,15 @@ class LLMEngine:
...
@@ -475,17 +504,15 @@ class LLMEngine:
sampler_outputs
=
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
sampler_outputs
=
output
,
num_seq_groups
=
len
(
scheduled_seq_groups
))
# Update the scheduled sequence groups with the model outputs.
# Update the scheduled sequence groups with the model outputs.
for
scheduled_seq_group
,
outputs
in
zip
(
scheduled_seq_groups
,
for
scheduled_seq_group
,
outputs
,
seq_group_meta
in
zip
(
output_by_sequence_group
):
scheduled_seq_groups
,
output_by_sequence_group
,
seq_group_metadata_list
):
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
scheduled_seq_group
.
token_chunk_size
)
# If all sequences in the sequence group are in DECODE, then we can
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
outputs
)
# process the output tokens. Otherwise, they are (chunked) prefill
if
seq_group_meta
.
do_sample
:
# samples and should not be processed.
stages
=
[
seq
.
data
.
_stage
for
seq
in
seq_group
.
seqs_dict
.
values
()]
if
all
(
stage
==
SequenceStage
.
DECODE
for
stage
in
stages
):
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
self
.
output_processor
.
process_outputs
(
seq_group
,
outputs
)
# Free the finished sequence groups.
# Free the finished sequence groups.
...
@@ -557,30 +584,36 @@ class LLMEngine:
...
@@ -557,30 +584,36 @@ class LLMEngine:
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
.
schedule
()
if
not
scheduler_outputs
.
is_empty
():
if
not
scheduler_outputs
.
is_empty
():
output
=
self
.
model_executor
.
e
xecute
_m
odel
(
execute_model_req
=
E
xecute
M
odel
Request
(
seq_group_metadata_list
=
seq_group_metadata_list
,
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_swap_out
=
scheduler_outputs
.
blocks_to_swap_out
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
blocks_to_copy
=
scheduler_outputs
.
blocks_to_copy
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
)
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
)
output
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
else
:
else
:
output
=
[]
output
=
[]
request_outputs
=
self
.
_process_model_outputs
(
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
)
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
# Log stats.
# Log stats.
if
self
.
log_stats
:
self
.
do_log_stats
(
scheduler_outputs
,
output
)
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
,
model_output
=
output
))
return
request_outputs
return
request_outputs
def
do_log_stats
(
self
)
->
None
:
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
)
->
None
:
"""Forced log when no requests active."""
"""Forced log when no requests active."""
if
self
.
log_stats
:
if
self
.
log_stats
:
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
=
None
))
self
.
stat_logger
.
log
(
self
.
_get_stats
(
scheduler_outputs
,
model_output
))
def
_get_stats
(
def
_get_stats
(
self
,
self
,
...
@@ -596,59 +629,109 @@ class LLMEngine:
...
@@ -596,59 +629,109 @@ class LLMEngine:
"""
"""
now
=
time
.
time
()
now
=
time
.
time
()
# KV Cache Usage in %.
# System State
# Scheduler State
num_running_sys
=
len
(
self
.
scheduler
.
running
)
num_swapped_sys
=
len
(
self
.
scheduler
.
swapped
)
num_waiting_sys
=
len
(
self
.
scheduler
.
waiting
)
# KV Cache Usage in %
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
num_total_gpu
=
self
.
cache_config
.
num_gpu_blocks
num_free_gpu
=
self
.
scheduler
.
block_manager
.
get_num_free_gpu_blocks
()
num_free_gpu
=
self
.
scheduler
.
block_manager
.
get_num_free_gpu_blocks
()
gpu_cache_usage
=
1.0
-
(
num_free_gpu
/
num_total_gpu
)
gpu_cache_usage
_sys
=
1.0
-
(
num_free_gpu
/
num_total_gpu
)
num_total_cpu
=
self
.
cache_config
.
num_cpu_blocks
num_total_cpu
=
self
.
cache_config
.
num_cpu_blocks
cpu_cache_usage
=
0.
cpu_cache_usage
_sys
=
0.
if
num_total_cpu
>
0
:
if
num_total_cpu
>
0
:
num_free_cpu
=
self
.
scheduler
.
block_manager
.
get_num_free_cpu_blocks
(
num_free_cpu
=
self
.
scheduler
.
block_manager
.
get_num_free_cpu_blocks
(
)
)
cpu_cache_usage
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
cpu_cache_usage_sys
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
# Scheduler State
# Iteration stats
num_running
=
len
(
self
.
scheduler
.
running
)
num_prompt_tokens_iter
=
0
num_swapped
=
len
(
self
.
scheduler
.
swapped
)
num_generation_tokens_iter
=
0
num_waiting
=
len
(
self
.
scheduler
.
waiting
)
time_to_first_tokens_iter
:
List
[
float
]
=
[]
time_per_output_tokens_iter
:
List
[
float
]
=
[]
# Iteration stats if we have scheduler output.
num_prompt_tokens
=
0
# Request stats
num_generation_tokens
=
0
# Latency
time_to_first_tokens
=
[]
time_e2e_requests
:
List
[
float
]
=
[]
time_per_output_tokens
=
[]
# Metadata
time_e2e_requests
=
[]
num_prompt_tokens_requests
:
List
[
int
]
=
[]
num_generation_tokens_requests
:
List
[
int
]
=
[]
best_of_requests
:
List
[
int
]
=
[]
n_requests
:
List
[
int
]
=
[]
finished_reason_requests
:
List
[
str
]
=
[]
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
if
scheduler_outputs
is
not
None
:
if
scheduler_outputs
is
not
None
:
prompt_run
=
scheduler_outputs
.
num_prefill_groups
>
0
num_generation_tokens_from_prefill_groups
=
0.
# NOTE: if scheduler_outputs.num_prefill_groups > 0 and
# Number of Tokens.
# the len of scheduler_outputs.scheduled_seq_groups is !=
if
prompt_run
:
# scheduler_outputs.num_prefill_groups, this means that
num_prompt_tokens
=
sum
(
# chunked prefills have been detected.
len
(
scheduled_seq_group
.
seq_group
.
prompt_token_ids
)
for
scheduled_seq_group
in
for
idx
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
)
scheduler_outputs
.
scheduled_seq_groups
):
num_generation_tokens
=
sum
(
group_was_prefill
=
idx
<
scheduler_outputs
.
num_prefill_groups
scheduled_seq_group
.
seq_group
.
num_seqs
()
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
)
else
:
num_generation_tokens
=
scheduler_outputs
.
num_batched_tokens
# Latency Timings.
time_last_iters
=
[]
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
# Time since last token.
# (n.b. updates seq_group.metrics.last_token_time)
# NOTE: a seq_group that completed all of its prefill tokens
time_last_iters
.
append
(
seq_group
.
get_last_latency
(
now
))
# in the last iteration will have seq_group.is_prefill() = False
# Time since arrival for all finished requests.
# with group_was_prefill = True
if
group_was_prefill
:
# Number of prompt tokens.
num_prompt_tokens_iter
+=
(
scheduled_seq_group
.
token_chunk_size
)
# If the seq_group just finished the prefill state
# get TTFT.
if
not
seq_group
.
is_prefill
():
latency
=
seq_group
.
get_last_latency
(
now
)
time_to_first_tokens_iter
.
append
(
latency
)
# One generation token per finished prefill.
num_generation_tokens_from_prefill_groups
+=
(
seq_group
.
num_seqs
())
else
:
# TPOTs.
latency
=
seq_group
.
get_last_latency
(
now
)
time_per_output_tokens_iter
.
append
(
latency
)
# Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging
# the same metadata more than once per request, we standardize
# on logging request level information for finished requests,
# which can only happen once.
if
seq_group
.
is_finished
():
if
seq_group
.
is_finished
():
# Latency timings
time_e2e_requests
.
append
(
now
-
time_e2e_requests
.
append
(
now
-
seq_group
.
metrics
.
arrival_time
)
seq_group
.
metrics
.
arrival_time
)
time_to_first_tokens
=
time_last_iters
if
prompt_run
else
[]
# Metadata
time_per_output_tokens
=
[]
if
prompt_run
else
time_last_iters
num_prompt_tokens_requests
.
append
(
len
(
seq_group
.
prompt_token_ids
))
num_generation_tokens_requests
.
extend
([
seq
.
get_output_len
()
for
seq
in
seq_group
.
get_finished_seqs
()
])
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
finished_reason_requests
.
extend
([
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
for
seq
in
seq_group
.
get_finished_seqs
()
])
# Number of generation tokens.
# num_batched_tokens equals the number of prompt_tokens plus the
# number of decode_tokens in a single iteration. So,
# num_generation_tokens = num_batched_tokens - num_prompt_tokens
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter
=
(
scheduler_outputs
.
num_batched_tokens
-
num_prompt_tokens_iter
+
num_generation_tokens_from_prefill_groups
)
# Spec decode, if enabled, emits specialized metrics from the worker in
# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
# sampler output.
...
@@ -660,17 +743,32 @@ class LLMEngine:
...
@@ -660,17 +743,32 @@ class LLMEngine:
return
Stats
(
return
Stats
(
now
=
now
,
now
=
now
,
num_running
=
num_running
,
num_swapped
=
num_swapped
,
# System stats
num_waiting
=
num_waiting
,
# Scheduler State
gpu_cache_usage
=
gpu_cache_usage
,
num_running_sys
=
num_running_sys
,
cpu_cache_usage
=
cpu_cache_usage
,
num_swapped_sys
=
num_swapped_sys
,
num_prompt_tokens
=
num_prompt_tokens
,
num_waiting_sys
=
num_waiting_sys
,
num_generation_tokens
=
num_generation_tokens
,
# KV Cache Usage in %
time_to_first_tokens
=
time_to_first_tokens
,
gpu_cache_usage_sys
=
gpu_cache_usage_sys
,
time_per_output_tokens
=
time_per_output_tokens
,
cpu_cache_usage_sys
=
cpu_cache_usage_sys
,
time_e2e_requests
=
time_e2e_requests
,
# Iteration stats
num_prompt_tokens_iter
=
num_prompt_tokens_iter
,
num_generation_tokens_iter
=
num_generation_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_tokens_iter
,
spec_decode_metrics
=
spec_decode_metrics
,
spec_decode_metrics
=
spec_decode_metrics
,
# Request stats
# Latency
time_e2e_requests
=
time_e2e_requests
,
# Metadata
num_prompt_tokens_requests
=
num_prompt_tokens_requests
,
num_generation_tokens_requests
=
num_generation_tokens_requests
,
best_of_requests
=
best_of_requests
,
n_requests
=
n_requests
,
finished_reason_requests
=
finished_reason_requests
,
)
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
...
...
vllm/engine/metrics.py
View file @
1591c68f
import
time
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Protocol
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Protocol
,
Union
import
numpy
as
np
import
numpy
as
np
from
prometheus_client
import
(
REGISTRY
,
Counter
,
Gauge
,
Histogram
,
Info
,
from
prometheus_client
import
(
REGISTRY
,
Counter
,
Gauge
,
Histogram
,
Info
,
...
@@ -21,8 +23,9 @@ disable_created_metrics()
...
@@ -21,8 +23,9 @@ disable_created_metrics()
# begin-metrics-definitions
# begin-metrics-definitions
class
Metrics
:
class
Metrics
:
labelname_finish_reason
=
"finished_reason"
def
__init__
(
self
,
labelnames
:
List
[
str
]):
def
__init__
(
self
,
labelnames
:
List
[
str
]
,
max_model_len
:
int
):
# Unregister any existing vLLM collectors
# Unregister any existing vLLM collectors
for
collector
in
list
(
REGISTRY
.
_collector_to_names
):
for
collector
in
list
(
REGISTRY
.
_collector_to_names
):
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
if
hasattr
(
collector
,
"_name"
)
and
"vllm"
in
collector
.
_name
:
...
@@ -34,18 +37,20 @@ class Metrics:
...
@@ -34,18 +37,20 @@ class Metrics:
documentation
=
'information of cache_config'
)
documentation
=
'information of cache_config'
)
# System stats
# System stats
# Scheduler State
self
.
gauge_scheduler_running
=
Gauge
(
self
.
gauge_scheduler_running
=
Gauge
(
name
=
"vllm:num_requests_running"
,
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests currently running on GPU."
,
documentation
=
"Number of requests currently running on GPU."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
gauge_scheduler_swapped
=
Gauge
(
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
self
.
gauge_scheduler_waiting
=
Gauge
(
self
.
gauge_scheduler_waiting
=
Gauge
(
name
=
"vllm:num_requests_waiting"
,
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
self
.
gauge_scheduler_swapped
=
Gauge
(
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
# KV Cache Usage in %
self
.
gauge_gpu_cache_usage
=
Gauge
(
self
.
gauge_gpu_cache_usage
=
Gauge
(
name
=
"vllm:gpu_cache_usage_perc"
,
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
...
@@ -55,7 +60,7 @@ class Metrics:
...
@@ -55,7 +60,7 @@ class Metrics:
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
)
#
Raw stats from last model iteration
#
Iteration stats
self
.
counter_prompt_tokens
=
Counter
(
self
.
counter_prompt_tokens
=
Counter
(
name
=
"vllm:prompt_tokens_total"
,
name
=
"vllm:prompt_tokens_total"
,
documentation
=
"Number of prefill tokens processed."
,
documentation
=
"Number of prefill tokens processed."
,
...
@@ -80,18 +85,51 @@ class Metrics:
...
@@ -80,18 +85,51 @@ class Metrics:
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
1.0
,
2.5
1.0
,
2.5
])
])
self
.
histogram_e2e_request_latency
=
Histogram
(
# Request stats
# Latency
self
.
histogram_e2e_time_request
=
Histogram
(
name
=
"vllm:e2e_request_latency_seconds"
,
name
=
"vllm:e2e_request_latency_seconds"
,
documentation
=
"Histogram of end to end request latency in seconds."
,
documentation
=
"Histogram of end to end request latency in seconds."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
buckets
=
[
1.0
,
2.5
,
5.0
,
10.0
,
15.0
,
20.0
,
30.0
,
40.0
,
50.0
,
60.0
])
# Metadata
self
.
histogram_num_prompt_tokens_request
=
Histogram
(
name
=
"vllm:request_prompt_tokens"
,
documentation
=
"Number of prefill tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_num_generation_tokens_request
=
Histogram
(
name
=
"vllm:request_generation_tokens"
,
documentation
=
"Number of generation tokens processed."
,
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
Histogram
(
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
histogram_n_request
=
Histogram
(
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
counter_request_success
=
Counter
(
name
=
"vllm:request_success_total"
,
documentation
=
"Count of successfully processed requests."
,
labelnames
=
labelnames
+
[
Metrics
.
labelname_finish_reason
])
#
Legacy metrics
#
Deprecated in favor of vllm:prompt_tokens_total
self
.
gauge_avg_prompt_throughput
=
Gauge
(
self
.
gauge_avg_prompt_throughput
=
Gauge
(
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
)
)
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
Gauge
(
self
.
gauge_avg_generation_throughput
=
Gauge
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
documentation
=
"Average generation throughput in tokens/s."
,
...
@@ -102,24 +140,57 @@ class Metrics:
...
@@ -102,24 +140,57 @@ class Metrics:
# end-metrics-definitions
# end-metrics-definitions
def build_1_2_5_buckets(max_value: int):
    """Return the ascending 1-2-5 series of bucket bounds up to ``max_value``.

    Every bound is one of the mantissas (1, 2, 5) multiplied by a power of
    ten. Bounds are produced in increasing order; generation stops at the
    first bound that is not ``<= max_value``, so a bound equal to
    ``max_value`` is kept and an empty list is returned when even ``1`` does
    not fit.

    Example:
        >>> build_1_2_5_buckets(100)
        [1, 2, 5, 10, 20, 50, 100]
    """
    bounds = []
    scale = 1  # Current power of ten: 1, 10, 100, ...
    while True:
        for mantissa in (1, 2, 5):
            bound = mantissa * scale
            if not bound <= max_value:
                # Bounds only grow from here, so nothing further can fit.
                return bounds
            bounds.append(bound)
        scale *= 10
@
dataclass
@
dataclass
class
Stats
:
class
Stats
:
"""Created by LLMEngine for use by StatLogger."""
"""Created by LLMEngine for use by StatLogger."""
now
:
float
now
:
float
# System stats.
# System stats (should have _sys suffix)
num_running
:
int
# Scheduler State
num_waiting
:
int
num_running_sys
:
int
num_swapped
:
int
num_waiting_sys
:
int
gpu_cache_usage
:
float
num_swapped_sys
:
int
cpu_cache_usage
:
float
# KV Cache Usage in %
gpu_cache_usage_sys
:
float
# Raw stats from last model iteration.
cpu_cache_usage_sys
:
float
num_prompt_tokens
:
int
num_generation_tokens
:
int
# Iteration stats (should have _iter suffix)
time_to_first_tokens
:
List
[
float
]
num_prompt_tokens_iter
:
int
time_per_output_tokens
:
List
[
float
]
num_generation_tokens_iter
:
int
time_to_first_tokens_iter
:
List
[
float
]
time_per_output_tokens_iter
:
List
[
float
]
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests
:
List
[
float
]
time_e2e_requests
:
List
[
float
]
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
num_generation_tokens_requests
:
List
[
int
]
best_of_requests
:
List
[
int
]
n_requests
:
List
[
int
]
finished_reason_requests
:
List
[
str
]
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
...
@@ -133,7 +204,8 @@ class SupportsMetricsInfo(Protocol):
...
@@ -133,7 +204,8 @@ class SupportsMetricsInfo(Protocol):
class
StatLogger
:
class
StatLogger
:
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
def
__init__
(
self
,
local_interval
:
float
,
labels
:
Dict
[
str
,
str
])
->
None
:
def
__init__
(
self
,
local_interval
:
float
,
labels
:
Dict
[
str
,
str
],
max_model_len
:
int
)
->
None
:
# Metadata for logging locally.
# Metadata for logging locally.
self
.
last_local_log
=
time
.
time
()
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
local_interval
=
local_interval
...
@@ -144,7 +216,8 @@ class StatLogger:
...
@@ -144,7 +216,8 @@ class StatLogger:
# Prometheus metrics
# Prometheus metrics
self
.
labels
=
labels
self
.
labels
=
labels
self
.
metrics
=
Metrics
(
labelnames
=
list
(
labels
.
keys
()))
self
.
metrics
=
Metrics
(
labelnames
=
list
(
labels
.
keys
()),
max_model_len
=
max_model_len
)
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
if
type
==
"cache_config"
:
if
type
==
"cache_config"
:
...
@@ -158,34 +231,66 @@ class StatLogger:
...
@@ -158,34 +231,66 @@ class StatLogger:
return
elapsed_time
>
self
.
local_interval
return
elapsed_time
>
self
.
local_interval
def
_log_prometheus
(
self
,
stats
:
Stats
)
->
None
:
def
_log_prometheus
(
self
,
stats
:
Stats
)
->
None
:
# Set system stat gauges.
# System state data
self
.
metrics
.
gauge_scheduler_running
.
labels
(
**
self
.
labels
).
set
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_scheduler_running
,
stats
.
num_running
)
stats
.
num_running_sys
)
self
.
metrics
.
gauge_scheduler_swapped
.
labels
(
**
self
.
labels
).
set
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_scheduler_swapped
,
stats
.
num_swapped
)
stats
.
num_swapped_sys
)
self
.
metrics
.
gauge_scheduler_waiting
.
labels
(
**
self
.
labels
).
set
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_scheduler_waiting
,
stats
.
num_waiting
)
stats
.
num_waiting_sys
)
self
.
metrics
.
gauge_gpu_cache_usage
.
labels
(
**
self
.
labels
).
set
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_gpu_cache_usage
,
stats
.
gpu_cache_usage
)
stats
.
gpu_cache_usage_sys
)
self
.
metrics
.
gauge_cpu_cache_usage
.
labels
(
**
self
.
labels
).
set
(
self
.
_log_gauge
(
self
.
metrics
.
gauge_cpu_cache_usage
,
stats
.
cpu_cache_usage
)
stats
.
cpu_cache_usage_sys
)
# Add to token counters.
# Iteration level data
self
.
metrics
.
counter_prompt_tokens
.
labels
(
**
self
.
labels
).
inc
(
self
.
_log_counter
(
self
.
metrics
.
counter_prompt_tokens
,
stats
.
num_prompt_tokens
)
stats
.
num_prompt_tokens_iter
)
self
.
metrics
.
counter_generation_tokens
.
labels
(
**
self
.
labels
).
inc
(
self
.
_log_counter
(
self
.
metrics
.
counter_generation_tokens
,
stats
.
num_generation_tokens
)
stats
.
num_generation_tokens_iter
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_time_to_first_token
,
# Observe request level latencies in histograms.
stats
.
time_to_first_tokens_iter
)
for
ttft
in
stats
.
time_to_first_tokens
:
self
.
_log_histogram
(
self
.
metrics
.
histogram_time_per_output_token
,
self
.
metrics
.
histogram_time_to_first_token
.
labels
(
stats
.
time_per_output_tokens_iter
)
**
self
.
labels
).
observe
(
ttft
)
for
tpot
in
stats
.
time_per_output_tokens
:
# Request level data
self
.
metrics
.
histogram_time_per_output_token
.
labels
(
# Latency
**
self
.
labels
).
observe
(
tpot
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_e2e_time_request
,
for
e2e
in
stats
.
time_e2e_requests
:
stats
.
time_e2e_requests
)
self
.
metrics
.
histogram_e2e_request_latency
.
labels
(
# Metadata
**
self
.
labels
).
observe
(
e2e
)
finished_reason_counter
=
CollectionsCounter
(
stats
.
finished_reason_requests
)
self
.
_log_counter_labels
(
self
.
metrics
.
counter_request_success
,
finished_reason_counter
,
Metrics
.
labelname_finish_reason
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_num_prompt_tokens_request
,
stats
.
num_prompt_tokens_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_num_generation_tokens_request
,
stats
.
num_generation_tokens_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_n_request
,
stats
.
n_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_best_of_request
,
stats
.
best_of_requests
)
def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None:
    """Set ``gauge`` to ``data`` on the child selected by ``self.labels``."""
    labelled_gauge = gauge.labels(**self.labels)
    labelled_gauge.set(data)
def _log_counter(self, counter: Counter, data: Union[int, float]) -> None:
    """Increment ``counter`` by ``data`` on the child selected by ``self.labels``."""
    labelled_counter = counter.labels(**self.labels)
    labelled_counter.inc(data)
def _log_counter_labels(self, counter: Counter, data: CollectionsCounter,
                        label_key: str) -> None:
    """Increment ``counter`` once per distinct label value found in ``data``.

    For every ``(value, count)`` pair in ``data`` the counter child is
    selected with ``self.labels`` plus ``label_key=value`` and incremented
    by ``count``.
    """
    for label_value, occurrences in data.items():
        # Copy the base labels so the shared dict is never mutated.
        child_labels = {**self.labels}
        child_labels[label_key] = label_value
        counter.labels(**child_labels).inc(occurrences)
def _log_histogram(self, histogram: Histogram,
                   data: Union[List[int], List[float]]) -> None:
    """Observe every element of ``data`` on ``histogram``.

    The child selected by ``self.labels`` is looked up once per element, so
    an empty ``data`` performs no call on ``histogram`` at all.
    """
    for sample in data:
        labelled_histogram = histogram.labels(**self.labels)
        labelled_histogram.observe(sample)
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
generation_throughput
:
float
)
->
None
:
generation_throughput
:
float
)
->
None
:
...
@@ -210,8 +315,8 @@ class StatLogger:
...
@@ -210,8 +315,8 @@ class StatLogger:
self
.
_log_prometheus
(
stats
)
self
.
_log_prometheus
(
stats
)
# Save tracked stats for token counters.
# Save tracked stats for token counters.
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens
)
self
.
num_prompt_tokens
.
append
(
stats
.
num_prompt_tokens
_iter
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens
)
self
.
num_generation_tokens
.
append
(
stats
.
num_generation_tokens
_iter
)
# Log locally every local_interval seconds.
# Log locally every local_interval seconds.
if
self
.
_local_interval_elapsed
(
stats
.
now
):
if
self
.
_local_interval_elapsed
(
stats
.
now
):
...
@@ -227,14 +332,19 @@ class StatLogger:
...
@@ -227,14 +332,19 @@ class StatLogger:
# Log to stdout.
# Log to stdout.
logger
.
info
(
logger
.
info
(
f
"Avg prompt throughput:
{
prompt_throughput
:.
1
f
}
tokens/s, "
"Avg prompt throughput: %.1f tokens/s, "
f
"Avg generation throughput: "
"Avg generation throughput: %.1f tokens/s, "
f
"
{
generation_throughput
:.
1
f
}
tokens/s, "
"Running: %d reqs, Swapped: %d reqs, "
f
"Running:
{
stats
.
num_running
}
reqs, "
"Pending: %d reqs, GPU KV cache usage: %.1f%%, "
f
"Swapped:
{
stats
.
num_swapped
}
reqs, "
"CPU KV cache usage: %.1f%%"
,
f
"Pending:
{
stats
.
num_waiting
}
reqs, "
prompt_throughput
,
f
"GPU KV cache usage:
{
stats
.
gpu_cache_usage
*
100
:.
1
f
}
%, "
generation_throughput
,
f
"CPU KV cache usage:
{
stats
.
cpu_cache_usage
*
100
:.
1
f
}
%"
)
stats
.
num_running_sys
,
stats
.
num_swapped_sys
,
stats
.
num_waiting_sys
,
stats
.
gpu_cache_usage_sys
*
100
,
stats
.
cpu_cache_usage_sys
*
100
,
)
# Reset tracked stats for next interval.
# Reset tracked stats for next interval.
self
.
num_prompt_tokens
=
[]
self
.
num_prompt_tokens
=
[]
...
...
vllm/engine/output_processor/interfaces.py
View file @
1591c68f
...
@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
...
@@ -68,3 +68,9 @@ class SequenceGroupOutputProcessor(ABC):
scheduler.
scheduler.
"""
"""
pass
pass
@abstractmethod
def process_prompt_logprob(self, seq_group: SequenceGroup,
                           outputs: List[SequenceGroupOutput]) -> None:
    """Update prompt logprobs received from outputs to seq_group.

    Declared with ``@abstractmethod``; the body here is a no-op placeholder.

    Args:
        seq_group: The sequence group whose prompt logprobs are updated.
        outputs: The outputs that carry the prompt logprobs.
    """
    pass
vllm/engine/output_processor/multi_step.py
View file @
1591c68f
import
functools
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
from
transformers
import
PreTrainedTokenizer
from
transformers
import
PreTrainedTokenizer
...
@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
...
@@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -44,6 +45,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -44,6 +45,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
self
.
stop_checker
=
stop_checker
self
.
stop_checker
=
stop_checker
def process_prompt_logprob(self, seq_group: SequenceGroup,
                           outputs: List[SequenceGroupOutput]) -> None:
    """Emit the "prompt logprob unsupported" warning instead of processing.

    Neither ``seq_group`` nor ``outputs`` is read; the method only calls
    ``_log_prompt_logprob_unsupported_warning_once``.
    """
    # TODO(sang): Prompt logprob currently not implemented in multi step
    # workers.
    self._log_prompt_logprob_unsupported_warning_once()
@staticmethod
@functools.lru_cache()
def _log_prompt_logprob_unsupported_warning_once():
    """Log a warning that multi step workers do not support prompt logprobs.

    The function takes no arguments and is wrapped in
    ``functools.lru_cache``, so calls after the first are answered from the
    cache and do not log the warning again.
    """
    logger.warning(
        "Prompt logprob is not supported by multi step workers. "
        "(e.g., speculative decode uses multi step workers).")
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
def
process_outputs
(
self
,
sequence_group
:
SequenceGroup
,
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
outputs
:
List
[
SequenceGroupOutput
])
->
None
:
"""Append new tokens in the outputs to sequences in the sequence group.
"""Append new tokens in the outputs to sequences in the sequence group.
...
@@ -80,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -80,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
valid_samples
:
List
[
SequenceOutput
],
valid_samples
:
List
[
SequenceOutput
],
sampling_params
:
SamplingParams
)
->
None
:
sampling_params
:
SamplingParams
)
->
None
:
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_token_ids
=
[
sample
.
output_token
for
sample
in
valid_samples
]
output_logprobs
=
[
sample
.
logprobs
for
sample
in
valid_samples
]
# Truncate to max_tokens if necessary.
# Truncate to max_tokens if necessary.
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
...
@@ -104,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -104,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Incrementally append tokens to the sequence, as if we had only one new
# Incrementally append tokens to the sequence, as if we had only one new
# token.
# token.
for
output_token_id
in
output_token_ids
:
for
output_token_id
,
output_logprob
in
zip
(
output_token_ids
,
output_logprobs
):
seq
.
append_token_id
(
seq
.
append_token_id
(
token_id
=
output_token_id
,
token_id
=
output_token_id
,
# TODO emit logprobs in multi-step decoding.
logprobs
=
output_logprob
,
logprobs
=
{
output_token_id
:
Logprob
(
0.0
)},
)
)
new_char_count
=
0
new_char_count
=
0
...
...
vllm/engine/output_processor/single_step.py
View file @
1591c68f
...
@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -55,17 +55,23 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
),
f
"
{
type
(
self
)
}
does not support multiple outputs per step"
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
])
return
self
.
_process_sequence_group_outputs
(
sequence_group
,
outputs
[
0
])
def
_
process_
sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
process_
prompt_logprob
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
outputs
:
List
[
SequenceGroupOutput
]
)
->
None
:
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
# Process prompt logprobs
output
=
outputs
[
0
]
prompt_logprobs
=
output
s
.
prompt_logprobs
prompt_logprobs
=
output
.
prompt_logprobs
if
prompt_logprobs
is
not
None
and
\
if
(
prompt_logprobs
is
not
None
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
:
and
seq_group
.
sampling_params
.
detokenize
and
self
.
detokenizer
)
:
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
self
.
detokenizer
.
decode_prompt_logprobs_inplace
(
seq_group
,
prompt_logprobs
)
seq_group
,
prompt_logprobs
)
seq_group
.
prompt_logprobs
=
prompt_logprobs
if
not
seq_group
.
prompt_logprobs
:
# The first prompt token's logprob is None because it doesn't
# have tokens that are precedent.
seq_group
.
prompt_logprobs
=
[
None
]
seq_group
.
prompt_logprobs
.
extend
(
prompt_logprobs
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
)
->
None
:
# Process samples
# Process samples
samples
=
outputs
.
samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
...
...
vllm/engine/output_processor/util.py
View file @
1591c68f
from
typing
import
List
from
typing
import
List
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
sampler_outputs
:
List
[
SamplerOutput
],
def
create_output_by_sequence_group
(
num_seq_groups
:
int
):
sampler_outputs
:
List
[
SamplerOutput
],
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
"""Helper method which transforms a 2d list organized by
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
[step][sequence group] into [sequence group][step].
"""
"""
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment