Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53076d70
Commit
53076d70
authored
Mar 24, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.2' into v0.8.2-ori
parents
322a0be6
9c5c81b0
Changes
219
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
514 additions
and
186 deletions
+514
-186
vllm/transformers_utils/tokenizer_group/__init__.py
vllm/transformers_utils/tokenizer_group/__init__.py
+1
-1
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
...ransformers_utils/tokenizer_group/base_tokenizer_group.py
+0
-2
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
...transformers_utils/tokenizer_group/ray_tokenizer_group.py
+2
-8
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+0
-2
vllm/utils.py
vllm/utils.py
+97
-10
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+41
-2
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+2
-2
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+7
-10
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+10
-10
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+95
-35
vllm/v1/core/sched/__init__.py
vllm/v1/core/sched/__init__.py
+0
-0
vllm/v1/core/sched/interface.py
vllm/v1/core/sched/interface.py
+139
-0
vllm/v1/core/sched/output.py
vllm/v1/core/sched/output.py
+0
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+6
-31
vllm/v1/core/sched/utils.py
vllm/v1/core/sched/utils.py
+22
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+37
-21
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+24
-21
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+3
-3
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+27
-16
vllm/v1/engine/parallel_sampling.py
vllm/v1/engine/parallel_sampling.py
+1
-12
No files found.
vllm/transformers_utils/tokenizer_group/__init__.py
View file @
53076d70
...
...
@@ -18,7 +18,7 @@ else:
def
init_tokenizer_from_configs
(
model_config
:
ModelConfig
,
scheduler_config
:
SchedulerConfig
,
parallel_config
:
ParallelConfig
,
lora_config
:
LoRAConfig
):
lora_config
:
Optional
[
LoRAConfig
]
):
init_kwargs
=
dict
(
tokenizer_id
=
model_config
.
tokenizer
,
enable_lora
=
bool
(
lora_config
),
max_num_seqs
=
scheduler_config
.
max_num_seqs
,
...
...
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
View file @
53076d70
...
...
@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC):
@
abstractmethod
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group."""
...
...
@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC):
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group."""
...
...
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
View file @
53076d70
...
...
@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group.
...
...
@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor
=
actor
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
actor
.
encode
.
remote
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
))
except
ActorDiedError
as
e
:
...
...
@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor
=
self
.
_init_actor
()
try
:
ret
=
ray
.
get
(
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
actor
.
encode
.
remote
(
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
))
except
ActorDiedError
as
e
:
...
...
@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
"""Encode a prompt using the tokenizer group.
...
...
@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
original_actor
=
actor
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
actor
=
self
.
_init_actor
()
try
:
ret
=
await
actor
.
encode
.
remote
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
53076d70
...
...
@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup):
def
encode
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
tokenizer
=
self
.
get_lora_tokenizer
(
lora_request
)
...
...
@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup):
async
def
encode_async
(
self
,
prompt
:
str
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
add_special_tokens
:
Optional
[
bool
]
=
None
)
->
List
[
int
]:
tokenizer
=
await
self
.
get_lora_tokenizer_async
(
lora_request
)
...
...
vllm/utils.py
View file @
53076d70
...
...
@@ -153,6 +153,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8"
:
torch
.
uint8
,
"fp8_e4m3"
:
torch
.
uint8
,
"fp8_e5m2"
:
torch
.
uint8
,
"int8"
:
torch
.
int8
,
}
TORCH_DTYPE_TO_NUMPY_DTYPE
=
{
...
...
@@ -411,6 +412,11 @@ async def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
if
len
(
iterators
)
==
1
:
# Fast-path single iterator case.
async
for
item
in
iterators
[
0
]:
yield
0
,
item
return
loop
=
asyncio
.
get_running_loop
()
...
...
@@ -2142,20 +2148,53 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
ctx
.
destroy
(
linger
=
0
)
def
_check_multiproc_method
():
if
(
cuda_is_initialized
()
and
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
)
!=
"spawn"
):
logger
.
warning
(
"CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"troubleshooting.html#python-multiprocessing "
"for more information."
)
def is_in_ray_actor():
    """Check if we are in a Ray actor.

    Returns:
        True when Ray can be imported, reports itself as initialized, and
        the current runtime context has a non-None actor id. False
        otherwise, including when importing Ray raises ImportError.
    """
    try:
        import ray
        return (ray.is_initialized()
                and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        # Ray could not be imported, so we cannot be inside a Ray actor.
        return False
def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.

    Does nothing when VLLM_WORKER_MULTIPROC_METHOD is already "spawn".
    Otherwise, if CUDA is already initialized or we are running inside a Ray
    actor, logs a warning and overrides VLLM_WORKER_MULTIPROC_METHOD to
    "spawn" in ``os.environ``. In the Ray actor case, RAY_ADDRESS is also
    exported.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reason = None
    if cuda_is_initialized():
        reason = "CUDA is initialized"
    elif is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray
        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reason = "In a Ray actor and can only be spawned"

    if reason is not None:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/getting_started/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reason: %s", reason)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def
get_mp_context
():
_check_multiproc_method
()
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
...
...
@@ -2355,3 +2394,51 @@ def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
obj
[
key1
]
=
v2
else
:
obj
.
pop
(
key1
,
None
)
@contextlib.contextmanager
def cprofile_context(save_file: Optional[str] = None):
    """Run a cprofile

    Profiling is enabled on entry and disabled on exit; the results are
    emitted even when the wrapped body raises.

    Args:
        save_file: path to save the profile result. "1", "" or
            None will result in printing to stdout.
    """
    import cProfile

    prof = cProfile.Profile()
    prof.enable()

    try:
        yield
    finally:
        prof.disable()
        # Any falsy value and the literal "1" mean "print instead of save".
        if save_file and save_file != "1":
            prof.dump_stats(save_file)
        else:
            prof.print_stats(sort="cumtime")
def cprofile(save_file: Optional[str] = None, enabled: bool = True):
    """Decorator to profile a Python method using cProfile.

    Args:
        save_file: Path to save the profile result.
            If "1", None, or "", results will be printed to stdout.
        enabled: Set to false to turn this into a no-op

    Returns:
        A decorator that wraps the target callable. The wrapper preserves
        the callable's metadata via ``functools.wraps`` and returns whatever
        the callable returns.
    """

    def decorator(func: Callable):

        @wraps(func)
        def wrapper(*args, **kwargs):
            if not enabled:
                # If profiling is disabled, just call the function directly.
                return func(*args, **kwargs)

            with cprofile_context(save_file):
                return func(*args, **kwargs)

        return wrapper

    return decorator
vllm/v1/attention/backends/flash_attn.py
View file @
53076d70
...
...
@@ -6,17 +6,18 @@ from typing import TYPE_CHECKING, Any, Optional
import
numpy
as
np
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
if
TYPE_CHECKING
:
from
vllm.v1.core.sched
uler_
output
import
SchedulerOutput
from
vllm.v1.core.sched
.
output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
...
@@ -226,6 +227,9 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
assert
output
is
not
None
,
"Output tensor must be provided."
...
...
@@ -259,6 +263,17 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_v_scale
,
)
descale_shape
=
(
attn_metadata
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
num_tokens
,
num_heads
,
head_size
=
query
.
shape
query
,
_
=
ops
.
scaled_fp8_quant
(
query
.
reshape
(
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
...
...
@@ -279,6 +294,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
return
output
...
...
@@ -301,6 +319,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
)
return
output
...
...
@@ -391,6 +412,9 @@ def cascade_attention(
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
fa_version
:
int
,
q_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
k_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
v_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
alibi_slopes
is
None
,
(
"Cascade attention does not support ALiBi."
)
# TODO: Support sliding window.
...
...
@@ -402,6 +426,7 @@ def cascade_attention(
assert
common_prefix_len
%
block_size
==
0
num_common_kv_blocks
=
common_prefix_len
//
block_size
assert
num_common_kv_blocks
>
0
descale_shape
=
(
cu_prefix_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process shared prefix.
prefix_output
,
prefix_lse
=
flash_attn_varlen_func
(
...
...
@@ -419,8 +444,16 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
descale_shape
=
(
cu_query_lens
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
-
2
])
# Process suffix per query.
suffix_output
,
suffix_lse
=
flash_attn_varlen_func
(
q
=
query
,
...
...
@@ -437,6 +470,12 @@ def cascade_attention(
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
k_descale
=
k_descale
.
expand
(
descale_shape
)
if
k_descale
is
not
None
else
None
,
v_descale
=
v_descale
.
expand
(
descale_shape
)
if
v_descale
is
not
None
else
None
,
)
# Merge prefix and suffix outputs, and store the result in output.
...
...
vllm/v1/attention/backends/mla/common.py
View file @
53076d70
...
...
@@ -195,8 +195,8 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
...
...
@@ -212,7 +212,7 @@ except ImportError:
from
flash_attn
import
flash_attn_varlen_func
if
TYPE_CHECKING
:
from
vllm.v1.core.sched
uler_
output
import
SchedulerOutput
from
vllm.v1.core.sched
.
output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
...
...
vllm/v1/attention/backends/pallas.py
View file @
53076d70
...
...
@@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
block_size
,
num_kv_heads
*
head_size
)
@
staticmethod
def
swap_blocks
(
...
...
@@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_blocks, block_size, num_kv_heads
,
head_size],
[num_blocks, block_size, num_kv_heads
,
head_size])
kv_cache = ([num_blocks, block_size, num_kv_heads
*
head_size],
[num_blocks, block_size, num_kv_heads
*
head_size])
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
...
...
@@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
assert
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
num_tokens
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
num_tokens
,
self
.
num_kv_heads
,
self
.
head_size
)
key_cache
,
value_cache
=
kv_cache
if
kv_cache
[
0
].
numel
()
>
0
:
...
...
@@ -192,10 +190,10 @@ def write_to_kv_cache(
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads
,
head_size]
value: shape = [num_tokens, num_kv_heads
,
head_size]
k_cache = [num_blocks, block_size, num_kv_heads
,
head_size]
v_cache = [num_blocks, block_size, num_kv_heads
,
head_size]
key: shape = [num_tokens, num_kv_heads
*
head_size]
value: shape = [num_tokens, num_kv_heads
*
head_size]
k_cache = [num_blocks, block_size, num_kv_heads
*
head_size]
v_cache = [num_blocks, block_size, num_kv_heads
*
head_size]
"""
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
...
...
@@ -203,6 +201,5 @@ def write_to_kv_cache(
key_cache
=
key_cache
.
flatten
(
0
,
1
)
value_cache
=
value_cache
.
flatten
(
0
,
1
)
slot_mapping
=
slot_mapping
.
flatten
()
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
vllm/v1/attention/backends/
rocm
_attn.py
→
vllm/v1/attention/backends/
triton
_attn.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention
on rocm
"""
"""Attention layer with PagedAttention
and Triton prefix prefill.
"""
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import (
logger
=
init_logger
(
__name__
)
class
ROCm
AttentionBackend
(
AttentionBackend
):
class
Triton
AttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
...
...
@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend):
@
staticmethod
def
get_name
()
->
str
:
return
"
ROCM
_ATTN_VLLM_V1"
return
"
TRITON
_ATTN_VLLM_V1"
@
staticmethod
def
get_impl_cls
()
->
type
[
"
ROCm
AttentionImpl"
]:
return
ROCm
AttentionImpl
def
get_impl_cls
()
->
type
[
"
Triton
AttentionImpl"
]:
return
Triton
AttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
type
[
"AttentionMetadata"
]:
...
...
@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend):
return
FlashAttentionMetadataBuilder
class
ROCm
AttentionImpl
(
AttentionImpl
):
class
Triton
AttentionImpl
(
AttentionImpl
):
def
__init__
(
self
,
...
...
@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl):
)
->
None
:
if
blocksparse_params
is
not
None
:
raise
ValueError
(
"
ROCm
Attention does not support block-sparse attention."
)
"
Triton
Attention does not support block-sparse attention."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
...
...
@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
support_head_sizes
=
ROCm
AttentionBackend
.
get_supported_head_sizes
()
support_head_sizes
=
Triton
AttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
support_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by
ROCm
Attention. "
f
"Head size
{
head_size
}
is not supported by
Triton
Attention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"
ROCm
AttentionImpl"
)
"
Triton
AttentionImpl"
)
def
forward
(
self
,
...
...
vllm/v1/core/kv_cache_utils.py
View file @
53076d70
...
...
@@ -7,8 +7,8 @@ from typing import Any, NamedTuple, Optional
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCacheSpec
,
KVCacheTensor
)
from
vllm.v1.kv_cache_interface
import
(
KVCacheConfig
,
KVCache
Group
Spec
,
KVCacheSpec
,
KVCacheTensor
)
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
...
...
@@ -449,7 +449,7 @@ def hash_request_tokens(block_size: int,
def
check_enough_kv_cache_memory
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
]
,
available_memory
:
int
):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
...
...
@@ -457,7 +457,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of
each attention layer in
the model
available_memory: Memory available for KV cache in bytes.
Raises:
...
...
@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
f
"`max_model_len` when initializing the engine."
)
def
is_kv_cache_type_uniform
(
kv_cache_spec
:
KVCacheSpec
)
->
bool
:
def create_kv_cache_group_specs(
        kv_cache_spec: dict[str, KVCacheSpec],
        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
    """
    Create KVCacheGroupSpec object for each kv cache group layer.
    The layers in the same group should share the same
    KVCacheSpec.

    Args:
        kv_cache_spec:
            A mapping from each layer name to its corresponding KVCacheSpec.
        grouped_layer_names:
            A list of kv cache groups, where each element is a list of layer
            names that belong to the same group and should share the same
            KVCacheSpec.

    Returns:
        A list of KVCacheGroupSpec objects, one for each group.
    """
    kv_cache_groups = []
    for layer_names_one_group in grouped_layer_names:
        # The first layer's spec is the reference for the whole group.
        layer_spec = kv_cache_spec[layer_names_one_group[0]]
        assert all(
            kv_cache_spec[layer_name] == layer_spec
            for layer_name in layer_names_one_group[1:]), (
                "All layers in the same KV cache group must share the same "
                "KVCacheSpec.")
        kv_cache_groups.append(
            KVCacheGroupSpec(layer_names_one_group, layer_spec))
    return kv_cache_groups
def
is_kv_cache_type_uniform
(
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
])
->
bool
:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Args:
kv_cache_spec: The
KVC
ache
S
pec of the model
kv_cache_spec: The
kv c
ache
s
pec of
each attention layer in
the model
Returns:
True if all layers have the same type, False otherwise.
...
...
@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
def
_get_kv_cache_config_uniform_type
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
KVCacheSpec
,
available_memory
:
int
,
num_layers
:
int
)
->
KVCacheConfig
:
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
kv_cache_spec: The kv cache spec of
each attention layer in
the model
available_memory: Memory available for KV cache in bytes.
num_layers: The number of layers in the model.
Returns:
The generated KVCacheConfig
...
...
@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
assert
len
(
page_sizes
)
==
1
page_size
=
page_sizes
.
pop
()
num_blocks
=
int
(
available_memory
//
page_size
//
num_layers
)
num_blocks
=
int
(
available_memory
//
page_size
//
len
(
kv_cache_spec
)
)
num_blocks
=
max
(
num_blocks
,
0
)
if
vllm_config
.
cache_config
.
num_gpu_blocks_override
is
not
None
:
...
...
@@ -541,6 +570,9 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
max_model_len_str
,
max_concurrency
)
per_layer_size
=
page_size
*
num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names
=
[
list
(
kv_cache_spec
.
keys
())]
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_blocks
,
...
...
@@ -548,41 +580,69 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
layer_name
:
KVCacheTensor
(
size
=
per_layer_size
)
for
layer_name
in
kv_cache_spec
},
groups
=
[[
layer_name
for
layer_name
in
kv_cache_spec
]],
kv_cache_spec
=
kv_cache_spec
)
kv_cache_groups
=
create_kv_cache_group_specs
(
kv_cache_spec
,
grouped_layer_names
),
)
return
kv_cache_config
def
get_kv_cache_config
s
(
vllm_config
:
VllmConfig
,
kv_cache_spec
s
:
list
[
KVCacheSpec
],
available_memory
:
int
)
->
list
[
KVCacheConfig
]
:
def
get_kv_cache_config
(
vllm_config
:
VllmConfig
,
kv_cache_spec
:
dict
[
str
,
KVCacheSpec
],
available_memory
:
int
)
->
KVCacheConfig
:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Args:
vllm_config: The global VllmConfig
kv_cache_spec
s
: The kv cache spec
s
of the model
kv_cache_spec: The kv cache spec of
each attention layer in
the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfigs
"""
# Use the max number of layers to conservatively determine
# the number of blocks.
num_layers
=
max
(
len
(
kv_cache_spec
)
for
kv_cache_spec
in
kv_cache_specs
)
kv_cache_configs
=
[]
for
kv_cache_spec
in
kv_cache_specs
:
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
kv_cache_configs
.
append
(
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
available_memory
,
num_layers
))
else
:
raise
NotImplementedError
check_enough_kv_cache_memory
(
vllm_config
,
kv_cache_spec
,
available_memory
)
if
is_kv_cache_type_uniform
(
kv_cache_spec
):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return
_get_kv_cache_config_uniform_type
(
vllm_config
,
kv_cache_spec
,
available_memory
)
raise
NotImplementedError
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
    """
    Make the KV cache configurations for each worker consistent, so that all
    workers can be controlled by the same KVCacheManager.
    This function verifies that the layer group of each worker are the same,
    and changes the num_blocks of each worker to the smallest among all workers.

    Args:
        kv_cache_configs: The KV cache configurations for each worker. Will be
            in-place modified to make them consistent.

    Returns:
        The same ``kv_cache_configs`` list, after in-place modification.

    Raises:
        AssertionError: If two workers have a different number of KV cache
            groups, or if corresponding groups have different KV cache specs.
    """
    # Nothing to unify; also avoids calling min() on an empty sequence.
    if not kv_cache_configs:
        return kv_cache_configs

    # Sort the kv cache groups by the type_id of their KV cache spec.
    # This can avoid the inconsistency caused by the order of groups.
    for kv_cache_config in kv_cache_configs:
        kv_cache_config.kv_cache_groups.sort(
            key=lambda x: x.kv_cache_spec.type_id)

    # Verify that the groups of each rank are the same.
    groups_rank_0 = kv_cache_configs[0].kv_cache_groups
    for kv_cache_config in kv_cache_configs[1:]:
        # zip() stops at the shorter input, so a differing group count must
        # be rejected explicitly or it would pass unnoticed.
        assert len(kv_cache_config.kv_cache_groups) == len(groups_rank_0), (
            "All workers must have the same number of KV cache groups.")
        for group_rank_0, group_rank_i in zip(
                groups_rank_0, kv_cache_config.kv_cache_groups):
            assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec

    # Change the num_blocks of each rank to the smallest among all ranks. We
    # do not need to shrink the tensor size because it is valid to only use the
    # first `num_blocks` blocks of the tensor.
    min_num_blocks = min(kv_cache_config.num_blocks
                         for kv_cache_config in kv_cache_configs)
    for kv_cache_config in kv_cache_configs:
        kv_cache_config.num_blocks = min_num_blocks

    return kv_cache_configs
vllm/v1/core/sched/__init__.py
0 → 100644
View file @
53076d70
vllm/v1/core/sched/interface.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine
import
EngineCoreOutputs
from
vllm.v1.metrics.stats
import
SchedulerStats
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
class
SchedulerInterface
(
ABC
):
@
abstractmethod
def
schedule
(
self
)
->
"SchedulerOutput"
:
"""Schedule the requests to process in this scheduling step.
The scheduling decision is made at the iteration level. Each scheduling
step corresponds to a single forward pass of the model. Therefore, this
method is called repeatedly by a busy loop in the engine.
Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
that specifies how many tokens to process for each request in this
scheduling step. For example, num_tokens can be as large as the number
of prompt tokens for new requests, or it can be 1 for the requests that
are auto-regressively generating new tokens one by one. Otherwise, it
can be somewhere in between in case of chunked prefills, prefix caching,
speculative decoding, etc.
Additionally, the scheduler also returns useful data about each request
or the batch as a whole. The model runner will use this information in
preparing inputs to the model.
Returns:
A SchedulerOutput object containing information about the scheduled
requests.
"""
raise
NotImplementedError
@
abstractmethod
def
update_from_output
(
self
,
scheduler_output
:
"SchedulerOutput"
,
model_runner_output
:
"ModelRunnerOutput"
,
)
->
"EngineCoreOutputs"
:
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
requests. The model runner output includes generated token ids, draft
token ids for next step, etc. The scheduler uses this information to
update its states, checks the finished requests, and returns the output
for each request.
Returns:
A EngineCoreOutputs object containing the outputs for each request.
"""
raise
NotImplementedError
@
abstractmethod
def
add_request
(
self
,
request
:
"Request"
)
->
None
:
"""Add a new request to the scheduler's internal queue.
Args:
request: The new request being added.
"""
raise
NotImplementedError
@
abstractmethod
def
finish_requests
(
self
,
request_ids
:
Union
[
str
,
Iterable
[
str
]],
finished_status
:
"RequestStatus"
,
)
->
None
:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
This method is called in two cases:
1. When the request is aborted by the client.
2. When the frontend process detects a stop string of the request after
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
finished_status: The finished status of the given requests.
"""
raise
NotImplementedError
@
abstractmethod
def
get_num_unfinished_requests
(
self
)
->
int
:
"""Number of unfinished requests in the scheduler's internal queue."""
raise
NotImplementedError
def
has_unfinished_requests
(
self
)
->
bool
:
"""Returns True if there are unfinished requests in the scheduler's
internal queue."""
return
self
.
get_num_unfinished_requests
()
>
0
@
abstractmethod
def
has_finished_requests
(
self
)
->
bool
:
"""Returns True if there are finished requests that need to be cleared.
NOTE: This is different from `not self.has_unfinished_requests()`.
The scheduler maintains an internal list of the requests finished in the
previous step. This list is returned from the next call to schedule(),
to be sent to the model runner in the next step to clear cached states
for these finished requests.
This method checks if this internal list of finished requests is
non-empty. This information is useful for DP attention.
"""
raise
NotImplementedError
def
has_requests
(
self
)
->
bool
:
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return
self
.
has_unfinished_requests
()
or
self
.
has_finished_requests
()
@abstractmethod
def get_num_unscheduled_requests(self) -> int:
    """Return how many requests are not being processed by the executor."""
    raise NotImplementedError
@abstractmethod
def reset_prefix_cache(self) -> bool:
    """Reset the prefix cache used for the KV cache.

    A reset is particularly required when the model weights are
    live-updated.
    """
    raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
    """Build a SchedulerStats object to be used for logging.

    A SchedulerStats object is created for every scheduling step.
    """
    raise NotImplementedError
vllm/v1/core/sched
uler_
output.py
→
vllm/v1/core/sched
/
output.py
View file @
53076d70
File moved
vllm/v1/core/scheduler.py
→
vllm/v1/core/
sched/
scheduler.py
View file @
53076d70
...
...
@@ -13,8 +13,10 @@ from vllm.logger import init_logger
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
compute_encoder_budget
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.scheduler_output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.core.sched.utils
import
check_stop
from
vllm.v1.engine
import
(
EngineCoreEventType
,
EngineCoreOutput
,
EngineCoreOutputs
)
from
vllm.v1.metrics.stats
import
SchedulerStats
...
...
@@ -25,7 +27,7 @@ from vllm.v1.structured_output import StructuredOutputManager
logger
=
init_logger
(
__name__
)
class
Scheduler
:
class
Scheduler
(
SchedulerInterface
)
:
def
__init__
(
self
,
...
...
@@ -602,7 +604,7 @@ class Scheduler:
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped
=
self
.
_
check_stop
(
request
)
stopped
=
check_stop
(
request
,
self
.
max_model_len
)
if
stopped
:
self
.
_free_request
(
request
)
break
...
...
@@ -648,25 +650,6 @@ class Scheduler:
scheduler_stats
=
self
.
make_stats
(),
)
def
_check_stop
(
self
,
request
:
Request
)
->
bool
:
if
(
request
.
num_tokens
>=
self
.
max_model_len
or
request
.
num_output_tokens
>=
request
.
max_tokens
):
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
return
True
sampling_params
=
request
.
sampling_params
last_token_id
=
request
.
output_token_ids
[
-
1
]
if
(
not
sampling_params
.
ignore_eos
and
last_token_id
==
request
.
eos_token_id
):
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
return
True
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
()):
request
.
status
=
RequestStatus
.
FINISHED_STOPPED
request
.
stop_reason
=
last_token_id
return
True
return
False
def
add_request
(
self
,
request
:
Request
)
->
None
:
self
.
waiting
.
append
(
request
)
self
.
requests
[
request
.
request_id
]
=
request
...
...
@@ -715,17 +698,9 @@ class Scheduler:
def
get_num_unfinished_requests
(
self
)
->
int
:
return
len
(
self
.
waiting
)
+
len
(
self
.
running
)
def
has_unfinished_requests
(
self
)
->
bool
:
return
self
.
get_num_unfinished_requests
()
>
0
def
has_finished_requests
(
self
)
->
bool
:
return
len
(
self
.
finished_req_ids
)
>
0
def
has_requests
(
self
):
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return
self
.
has_unfinished_requests
()
or
self
.
has_finished_requests
()
def
get_num_unscheduled_requests
(
self
)
->
int
:
"""Number of requests that are not being processed by the executor."""
return
self
.
get_num_unfinished_requests
()
-
len
(
self
.
scheduled_req_ids
)
...
...
vllm/v1/core/sched/utils.py
0 → 100644
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
vllm.v1.request
import
Request
,
RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
    """Decide whether a request has finished and record why on the request.

    The checks run in this order, and the first one that matches wins:
        1. Length cap: the request's total token count has reached
           `max_model_len`, or its output token count has reached
           `request.max_tokens`. Status becomes FINISHED_LENGTH_CAPPED.
        2. EOS: `ignore_eos` is not set and the last output token equals
           `request.eos_token_id`. Status becomes FINISHED_STOPPED.
        3. Stop token: the last output token is one of
           `sampling_params.stop_token_ids`. Status becomes FINISHED_STOPPED
           and `request.stop_reason` is set to that token id.

    Args:
        request: The request to inspect; its `status` (and possibly
            `stop_reason`) is updated in place when a stop is detected.
        max_model_len: Upper bound compared against `request.num_tokens`.

    Returns:
        True if the request was marked as finished, False otherwise.
    """
    # Length limits are evaluated first, before the last token is read.
    hit_model_limit = request.num_tokens >= max_model_len
    if hit_model_limit or request.num_output_tokens >= request.max_tokens:
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True

    params = request.sampling_params
    newest_token = request.output_token_ids[-1]

    eos_enabled = not params.ignore_eos
    if eos_enabled and newest_token == request.eos_token_id:
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    # A missing/empty stop list is treated as "no stop tokens".
    stop_ids = params.stop_token_ids or ()
    if newest_token in stop_ids:
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = newest_token
        return True

    return False
vllm/v1/engine/async_llm.py
View file @
53076d70
...
...
@@ -4,6 +4,7 @@ import asyncio
import
logging
import
os
from
collections.abc
import
AsyncGenerator
,
Mapping
from
copy
import
copy
from
typing
import
Optional
,
Union
import
numpy
as
np
...
...
@@ -24,7 +25,8 @@ from vllm.sampling_params import RequestOutputKind, SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
cdiv
,
kill_process_tree
from
vllm.utils
import
Device
,
cdiv
,
kill_process_tree
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
...
...
@@ -177,33 +179,44 @@ class AsyncLLM(EngineClient):
)
->
asyncio
.
Queue
[
RequestOutput
]:
"""Add new request to the AsyncLLM."""
#
1)
Create a new output queue for the request.
# Create a new output queue for the request.
queue
:
asyncio
.
Queue
[
RequestOutput
]
=
asyncio
.
Queue
()
# 2) Fan out child requests (for n>1)
parent_req
=
ParentRequest
.
from_params
(
request_id
,
params
)
# Convert Input --> Request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
for
idx
in
range
(
n
):
if
parent_req
is
not
None
:
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
# 3) Convert Input --> Request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
if
n
==
1
:
await
self
.
_add_request
(
request
,
None
,
0
,
queue
)
return
queue
# 4) Add the request to OutputProcessor (this process).
self
.
output_processor
.
add_request
(
request
,
parent_req
,
idx
,
queue
)
# Fan out child requests (for n>1).
parent_request
=
ParentRequest
(
request_id
,
params
)
for
idx
in
range
(
n
):
request_id
,
params
=
parent_request
.
get_child_info
(
idx
)
child_request
=
request
if
idx
==
n
-
1
else
copy
(
request
)
child_request
.
request_id
=
request_id
child_request
.
sampling_params
=
params
await
self
.
_add_request
(
child_request
,
parent_request
,
idx
,
queue
)
return
queue
# 5) Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
request
)
async
def
_add_request
(
self
,
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
],
index
:
int
,
queue
:
asyncio
.
Queue
[
RequestOutput
]):
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request_id
)
# Add the request to OutputProcessor (this process).
self
.
output_processor
.
add_request
(
request
,
parent_req
,
index
,
queue
)
return
queue
# Add the EngineCoreRequest to EngineCore (separate process).
await
self
.
engine_core
.
add_request_async
(
request
)
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request
.
request_id
)
# TODO: we should support multiple prompts in one call, as you
# can do with LLM.generate. So that for multi-prompt completion
...
...
@@ -398,7 +411,10 @@ class AsyncLLM(EngineClient):
async
def
stop_profile
(
self
)
->
None
:
await
self
.
engine_core
.
profile_async
(
False
)
async
def
reset_prefix_cache
(
self
)
->
None
:
async
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
)
->
None
:
if
device
==
Device
.
CPU
:
raise
ValueError
(
"Not supported on CPU."
)
await
self
.
engine_core
.
reset_prefix_cache_async
()
async
def
sleep
(
self
,
level
:
int
=
1
)
->
None
:
...
...
vllm/v1/engine/core.py
View file @
53076d70
...
...
@@ -21,9 +21,10 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
(
get_exception_traceback
,
resolve_obj_by_qualname
,
zmq_socket_ctx
)
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_configs
from
vllm.v1.core.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.core.scheduler
import
SchedulerOutput
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
unify_kv_cache_configs
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
as
V1Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
,
UtilityOutput
)
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheServer
...
...
@@ -120,15 +121,27 @@ class EngineCore:
# memory can be allocated for kv cache.
available_gpu_memory
=
self
.
model_executor
.
determine_available_memory
()
assert
len
(
kv_cache_specs
)
==
len
(
available_gpu_memory
)
# Get the kv cache tensor size
kv_cache_configs
=
get_kv_cache_configs
(
vllm_config
,
kv_cache_specs
,
available_gpu_memory
)
num_gpu_blocks_set
=
set
(
config
.
num_blocks
for
config
in
kv_cache_configs
)
assert
len
(
num_gpu_blocks_set
)
==
1
,
(
f
"num_gpu_blocks need to be the same across workers, "
f
"but they are different:
{
num_gpu_blocks_set
}
"
)
num_gpu_blocks
=
num_gpu_blocks_set
.
pop
()
kv_cache_configs
=
[
get_kv_cache_config
(
vllm_config
,
kv_cache_spec_one_worker
,
available_gpu_memory_one_worker
)
for
kv_cache_spec_one_worker
,
available_gpu_memory_one_worker
in
zip
(
kv_cache_specs
,
available_gpu_memory
)
]
# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs
(
kv_cache_configs
)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
assert
all
([
cfg
.
num_blocks
==
kv_cache_configs
[
0
].
num_blocks
for
cfg
in
kv_cache_configs
])
num_gpu_blocks
=
kv_cache_configs
[
0
].
num_blocks
num_cpu_blocks
=
0
# Initialize kv cache and warmup the execution
...
...
@@ -179,16 +192,6 @@ class EngineCore:
scheduler_stats
=
self
.
scheduler
.
make_stats
(),
)
scheduler_output
=
self
.
scheduler
.
schedule
()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if
scheduler_output
.
total_num_scheduled_tokens
==
0
:
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
(),
)
output
=
self
.
model_executor
.
execute_model
(
scheduler_output
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
output
)
# type: ignore
...
...
vllm/v1/engine/core_client.py
View file @
53076d70
...
...
@@ -212,9 +212,9 @@ class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
ctx
:
Union
[
zmq
.
Context
]
=
None
output_socket
:
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
=
None
input_socket
:
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
=
None
ctx
:
zmq
.
Context
output_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
]
=
None
input_socket
:
Optional
[
Union
[
zmq
.
Socket
,
zmq
.
asyncio
.
Socket
]
]
=
None
proc_handle
:
Optional
[
BackgroundProcHandle
]
=
None
shutdown_path
:
Optional
[
str
]
=
None
...
...
vllm/v1/engine/llm_engine.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Mapping
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing_extensions
import
TypeVar
...
...
@@ -20,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
(
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
...
...
@@ -178,25 +180,34 @@ class LLMEngine:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
# 1) Fan out child requests (for n>1)
parent_req
=
ParentRequest
.
from_params
(
request_id
,
params
)
# Process raw inputs into the request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_request
,
priority
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
for
idx
in
range
(
n
):
if
parent_req
is
not
None
:
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
# 2) Process raw inputs into the request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
arrival_time
,
lora_request
,
trace_headers
,
prompt_adapter_
request
,
priority
)
if
n
==
1
:
# Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
request
,
None
,
0
)
# Add the request to EngineCore.
self
.
engine_core
.
add_request
(
request
)
return
# 3) Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
request
,
parent_req
,
idx
)
# Fan out child requests (for n>1).
parent_req
=
ParentRequest
(
request_id
,
params
)
for
idx
in
range
(
n
):
request_id
,
params
=
parent_req
.
get_child_info
(
idx
)
child_request
=
request
if
idx
==
n
-
1
else
copy
(
request
)
child_request
.
request_id
=
request_id
child_request
.
sampling_params
=
params
# 3) Add the request to EngineCore.
self
.
engine_core
.
add_request
(
request
)
# Make a new RequestState and queue.
self
.
output_processor
.
add_request
(
child_request
,
parent_req
,
idx
)
# Add the request to EngineCore.
self
.
engine_core
.
add_request
(
child_request
)
def
step
(
self
)
->
list
[
RequestOutput
]:
...
...
@@ -226,7 +237,7 @@ class LLMEngine:
def
stop_profile
(
self
):
self
.
engine_core
.
profile
(
False
)
def
reset_prefix_cache
(
self
):
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
):
self
.
engine_core
.
reset_prefix_cache
()
def
sleep
(
self
,
level
:
int
=
1
):
...
...
vllm/v1/engine/parallel_sampling.py
View file @
53076d70
# SPDX-License-Identifier: Apache-2.0
from
copy
import
copy
from
typing
import
Optional
,
Union
from
typing
import
Optional
from
vllm.outputs
import
CompletionOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.v1.metrics.stats
import
IterationStats
...
...
@@ -43,16 +42,6 @@ class ParentRequest:
self
.
max_num_generation_tokens
=
0
self
.
cached_child_sampling_params
=
None
@
classmethod
def
from_params
(
cls
,
request_id
:
str
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
Optional
[
'ParentRequest'
]:
if
not
isinstance
(
params
,
SamplingParams
)
or
params
.
n
==
1
:
return
None
return
cls
(
request_id
,
params
)
def
_get_child_sampling_params
(
self
,
index
:
int
,
...
...
Prev
1
…
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment