Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf5f000d
Unverified
Commit
cf5f000d
authored
Jan 10, 2025
by
Chen Zhang
Committed by
GitHub
Jan 10, 2025
Browse files
[torch.compile] Hide KV cache behind torch.compile boundary (#11677)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
3de2b1ea
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
198 additions
and
44 deletions
+198
-44
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+12
-6
tests/test_utils.py
tests/test_utils.py
+83
-2
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+3
-0
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+3
-0
vllm/attention/layer.py
vllm/attention/layer.py
+17
-12
vllm/config.py
vllm/config.py
+0
-1
vllm/forward_context.py
vllm/forward_context.py
+20
-13
vllm/utils.py
vllm/utils.py
+35
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-1
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+2
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+2
-1
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+2
-1
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+3
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+2
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-2
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+2
-1
vllm/worker/worker.py
vllm/worker/worker.py
+3
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-0
No files found.
tests/kernels/test_encoder_decoder_attn.py
View file @
cf5f000d
...
@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
...
@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
# Construct KV cache
# Construct KV cache
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
if
test_pt
.
attn_type
in
(
AttentionType
.
DECODER
,
test_pt
.
num_heads
,
AttentionType
.
ENCODER_DECODER
):
test_pt
.
head_size
,
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
test_pt
.
block_size
,
test_pt
.
num_heads
,
device
=
CUDA_DEVICE
,
test_pt
.
head_size
,
backend
=
test_pt
.
backend_name
)
test_pt
.
block_size
,
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
else
:
kv_cache
=
torch
.
tensor
([])
attn
.
kv_cache
=
[
kv_cache
]
return
TestResources
(
scale
,
attn
,
kv_cache
)
return
TestResources
(
scale
,
attn
,
kv_cache
)
...
...
tests/test_utils.py
View file @
cf5f000d
...
@@ -7,9 +7,11 @@ import pytest
...
@@ -7,9 +7,11 @@ import pytest
import
torch
import
torch
from
vllm_test_utils
import
monitor
from
vllm_test_utils
import
monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.utils
import
(
FlexibleArgumentParser
,
PlaceholderModule
,
from
vllm.utils
import
(
FlexibleArgumentParser
,
PlaceholderModule
,
StoreBoolean
,
deprecate_kwargs
,
get_open_port
,
StoreBoolean
,
bind_kv_cache
,
deprecate_kwargs
,
memory_profiling
,
merge_async_iterators
,
supports_kw
)
get_open_port
,
memory_profiling
,
merge_async_iterators
,
supports_kw
)
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
...
@@ -325,6 +327,85 @@ def test_memory_profiling():
...
@@ -325,6 +327,85 @@ def test_memory_profiling():
lib
.
cudaFree
(
handle2
)
lib
.
cudaFree
(
handle2
)
def
test_bind_kv_cache
():
from
vllm.attention
import
Attention
ctx
=
{
'layers.0.self_attn'
:
Attention
(
32
,
128
,
0.1
),
'layers.1.self_attn'
:
Attention
(
32
,
128
,
0.1
),
'layers.2.self_attn'
:
Attention
(
32
,
128
,
0.1
),
'layers.3.self_attn'
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
]
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'layers.1.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
1
]
assert
ctx
[
'layers.2.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
2
]
assert
ctx
[
'layers.3.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
3
]
def
test_bind_kv_cache_non_attention
():
from
vllm.attention
import
Attention
# example from Jamba PP=2
ctx
=
{
'model.layers.20.attn'
:
Attention
(
32
,
128
,
0.1
),
'model.layers.28.attn'
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
]
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'model.layers.20.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'model.layers.28.attn'
].
kv_cache
[
0
]
is
kv_cache
[
1
]
def
test_bind_kv_cache_encoder_decoder
():
from
vllm.attention
import
Attention
,
AttentionType
# example from bart
ctx
=
{
'encoder.layers.0.self_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
ENCODER
),
'decoder.layers.0.encoder_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
ENCODER_DECODER
),
'decoder.layers.0.self_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
DECODER
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
]
encoder_kv_cache
=
ctx
[
'encoder.layers.0.self_attn.attn'
].
kv_cache
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'encoder.layers.0.self_attn.attn'
].
kv_cache
is
encoder_kv_cache
assert
ctx
[
'decoder.layers.0.encoder_attn.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'decoder.layers.0.self_attn.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
def
test_bind_kv_cache_pp
():
cfg
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
2
))
with
set_current_vllm_config
(
cfg
):
from
vllm.attention
import
Attention
ctx
=
{
'layers.0.self_attn'
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
[
torch
.
zeros
((
1
,
))],
[
torch
.
zeros
((
1
,
))]
]
bind_kv_cache
(
ctx
,
kv_cache
)
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
][
0
]
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
def
test_placeholder_module_error_handling
():
def
test_placeholder_module_error_handling
():
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
...
...
tests/v1/engine/test_engine_core.py
View file @
cf5f000d
...
@@ -4,6 +4,7 @@ import uuid
...
@@ -4,6 +4,7 @@ import uuid
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
tests.utils
import
fork_new_process_for_each_test
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
...
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
)
)
@
fork_new_process_for_each_test
def
test_engine_core
(
monkeypatch
):
def
test_engine_core
(
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
...
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
...
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
@
fork_new_process_for_each_test
def
test_engine_core_advanced_sampling
(
monkeypatch
):
def
test_engine_core_advanced_sampling
(
monkeypatch
):
"""
"""
A basic end-to-end test to verify that the engine functions correctly
A basic end-to-end test to verify that the engine functions correctly
...
...
tests/v1/engine/test_engine_core_client.py
View file @
cf5f000d
...
@@ -6,6 +6,7 @@ from typing import Dict, List
...
@@ -6,6 +6,7 @@ from typing import Dict, List
import
pytest
import
pytest
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
tests.utils
import
fork_new_process_for_each_test
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
...
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break
break
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode"
,
[
True
,
False
])
def
test_engine_core_client
(
monkeypatch
,
multiprocessing_mode
:
bool
):
def
test_engine_core_client
(
monkeypatch
,
multiprocessing_mode
:
bool
):
...
@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
...
@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client
.
abort_requests
([
request
.
request_id
])
client
.
abort_requests
([
request
.
request_id
])
@
fork_new_process_for_each_test
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_engine_core_client_asyncio
(
monkeypatch
):
async
def
test_engine_core_client_asyncio
(
monkeypatch
):
...
...
vllm/attention/layer.py
View file @
cf5f000d
...
@@ -121,6 +121,13 @@ class Attention(nn.Module):
...
@@ -121,6 +121,13 @@ class Attention(nn.Module):
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
layer_name
=
prefix
self
.
attn_type
=
attn_type
self
.
attn_type
=
attn_type
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
get_current_vllm_config
(
).
parallel_config
.
pipeline_parallel_size
)
]
def
forward
(
def
forward
(
self
,
self
,
...
@@ -148,11 +155,11 @@ class Attention(nn.Module):
...
@@ -148,11 +155,11 @@ class Attention(nn.Module):
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
kv_cache
,
self
.
layer_name
)
query
,
key
,
value
,
output
,
self
.
layer_name
)
return
output
.
view
(
-
1
,
hidden_size
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
kv_cache
,
self
.
layer_name
)
self
.
layer_name
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
@@ -230,12 +237,12 @@ def unified_attention(
...
@@ -230,12 +237,12 @@ def unified_attention(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
static_forward_context
[
layer_name
]
self
=
forward_context
.
attn_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
)
self
.
_k_scale
,
self
.
_v_scale
)
...
@@ -244,7 +251,6 @@ def unified_attention_fake(
...
@@ -244,7 +251,6 @@ def unified_attention_fake(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
return
torch
.
empty_like
(
query
).
contiguous
()
...
@@ -253,7 +259,7 @@ def unified_attention_fake(
...
@@ -253,7 +259,7 @@ def unified_attention_fake(
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"unified_attention"
,
op_name
=
"unified_attention"
,
op_func
=
unified_attention
,
op_func
=
unified_attention
,
mutates_args
=
[
"kv_cache"
],
mutates_args
=
[],
fake_impl
=
unified_attention_fake
,
fake_impl
=
unified_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
...
@@ -264,12 +270,12 @@ def unified_attention_with_output(
...
@@ -264,12 +270,12 @@ def unified_attention_with_output(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
)
->
None
:
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
static_forward_context
[
layer_name
]
self
=
forward_context
.
attn_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
query
,
self
.
impl
.
forward
(
query
,
key
,
key
,
value
,
value
,
...
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
...
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
layer_name
:
str
,
)
->
None
:
)
->
None
:
return
return
...
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
...
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"unified_attention_with_output"
,
op_name
=
"unified_attention_with_output"
,
op_func
=
unified_attention_with_output
,
op_func
=
unified_attention_with_output
,
mutates_args
=
[
"kv_cache"
,
"output"
],
mutates_args
=
[
"output"
],
fake_impl
=
unified_attention_with_output_fake
,
fake_impl
=
unified_attention_with_output_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
)
vllm/config.py
View file @
cf5f000d
...
@@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
...
@@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
compilation_time
:
float
=
PrivateAttr
compilation_time
:
float
=
PrivateAttr
# Per-model forward context
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
# Map from layer name to the attention cls
static_forward_context
:
Dict
[
str
,
Any
]
=
PrivateAttr
static_forward_context
:
Dict
[
str
,
Any
]
=
PrivateAttr
...
...
vllm/forward_context.py
View file @
cf5f000d
...
@@ -2,7 +2,7 @@ import time
...
@@ -2,7 +2,7 @@ import time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
import
torch
import
torch
...
@@ -10,6 +10,9 @@ import vllm.envs as envs
...
@@ -10,6 +10,9 @@ import vllm.envs as envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
...
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
...
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@
dataclass
@
dataclass
class
ForwardContext
:
class
ForwardContext
:
static_forward_context
:
Dict
[
str
,
Any
]
# copy from vllm_config.compilation_config.static_forward_context
attn_layers
:
Dict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context
:
Any
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
_forward_context
:
Optional
[
ForwardContext
]
=
None
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
...
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
@
contextmanager
@
contextmanager
def
set_forward_context
(
context
:
Any
,
vllm_config
:
VllmConfig
):
def
set_forward_context
(
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc.
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
Here we can inject common logic for every model forward pass.
"""
"""
global
forward_start_time
global
forward_start_time
need_to_track_batchsize
=
track_batchsize
and
context
is
not
None
need_to_track_batchsize
=
track_batchsize
and
attn_metadata
is
not
None
if
need_to_track_batchsize
:
if
need_to_track_batchsize
:
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
_forward_context
=
ForwardContext
(
_forward_context
=
ForwardContext
(
static_forward_context
=
vllm_config
.
compilation_config
.
attn_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
static_forward_context
,
virtual_engine
=
virtual_engine
,
dynamic_forward_context
=
context
)
attn_metadata
=
attn_metadata
)
try
:
try
:
yield
yield
finally
:
finally
:
global
batchsize_counter
global
last_logging_time
,
batchsize_logging_interval
global
last_logging_time
,
batchsize_logging_interval
if
need_to_track_batchsize
:
if
need_to_track_batchsize
:
if
hasattr
(
context
,
"num_prefill_tokens"
):
if
hasattr
(
attn_metadata
,
"num_prefill_tokens"
):
# for v0 attention backends
# for v0 attention backends
batchsize
=
context
.
num_prefill_tokens
+
\
batchsize
=
attn_metadata
.
num_prefill_tokens
+
\
context
.
num_decode_tokens
attn_metadata
.
num_decode_tokens
else
:
else
:
# for v1 attention backends
# for v1 attention backends
batchsize
=
context
.
num_input_tokens
batchsize
=
attn_metadata
.
num_input_tokens
# we use synchronous scheduling right now,
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# adding a sync point here should not affect
# scheduling of the next batch
# scheduling of the next batch
...
...
vllm/utils.py
View file @
cf5f000d
...
@@ -2138,3 +2138,38 @@ def get_mp_context():
...
@@ -2138,3 +2138,38 @@ def get_mp_context():
_check_multiproc_method
()
_check_multiproc_method
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
return
multiprocessing
.
get_context
(
mp_method
)
def
bind_kv_cache
(
ctx
:
Dict
[
str
,
Any
],
kv_cache
:
List
[
List
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
)
->
None
:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# Special things handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. Pipeline parallelism, each rank only has a subset of layers
# 3. Encoder attention has no kv cache
# 4. Encoder-decoder models, encoder-decoder attention and decoder-only
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
from
vllm.attention
import
AttentionType
from
vllm.model_executor.models.utils
import
extract_layer_index
layer_need_kv_cache
=
[
layer_name
for
layer_name
in
ctx
if
ctx
[
layer_name
].
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
)
]
layer_index_sorted
=
sorted
(
set
(
extract_layer_index
(
layer_name
)
for
layer_name
in
layer_need_kv_cache
))
for
layer_name
in
layer_need_kv_cache
:
kv_cache_idx
=
layer_index_sorted
.
index
(
extract_layer_index
(
layer_name
))
forward_ctx
=
ctx
[
layer_name
]
assert
len
(
forward_ctx
.
kv_cache
)
==
len
(
kv_cache
)
for
ve
,
ve_kv_cache
in
enumerate
(
kv_cache
):
assert
forward_ctx
.
kv_cache
[
ve
].
numel
()
==
0
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
vllm/v1/worker/gpu_model_runner.py
View file @
cf5f000d
...
@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
...
@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
LayerBlockType
,
bind_kv_cache
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
FlashAttentionMetadata
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperClient
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperClient
...
@@ -860,3 +861,6 @@ class GPUModelRunner:
...
@@ -860,3 +861,6 @@ class GPUModelRunner:
torch
.
zeros
(
kv_cache_shape
,
torch
.
zeros
(
kv_cache_shape
,
dtype
=
self
.
kv_cache_dtype
,
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
))
device
=
self
.
device
))
bind_kv_cache
(
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
[
self
.
kv_caches
])
vllm/worker/cpu_enc_dec_model_runner.py
View file @
cf5f000d
...
@@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
...
@@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
intermediate_tensors
,
intermediate_tensors
,
}
}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
# Compute the logits.
...
...
vllm/worker/cpu_model_runner.py
View file @
cf5f000d
...
@@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
...
@@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
execute_model_kwargs
.
update
(
execute_model_kwargs
.
update
(
{
"previous_hidden_states"
:
previous_hidden_states
})
{
"previous_hidden_states"
:
previous_hidden_states
})
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/cpu_pooling_model_runner.py
View file @
cf5f000d
...
@@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
...
@@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
intermediate_tensors
,
intermediate_tensors
,
}
}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Only perform pooling in the driver worker.
# Only perform pooling in the driver worker.
...
...
vllm/worker/cpu_worker.py
View file @
cf5f000d
...
@@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
...
@@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
bind_kv_cache
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
,
CPUModelRunnerBase
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
,
CPUModelRunnerBase
from
vllm.worker.cpu_pooling_model_runner
import
CPUPoolingModelRunner
from
vllm.worker.cpu_pooling_model_runner
import
CPUPoolingModelRunner
...
@@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
...
@@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
cache_engine
[
ve
].
cpu_cache
self
.
cache_engine
[
ve
].
cpu_cache
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
self
.
cpu_cache
)
self
.
model_runner
.
block_size
=
self
.
cache_engine
[
0
].
block_size
self
.
model_runner
.
block_size
=
self
.
cache_engine
[
0
].
block_size
assert
all
(
assert
all
(
...
...
vllm/worker/enc_dec_model_runner.py
View file @
cf5f000d
...
@@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
}
if
self
.
has_inner_state
else
{}
}
if
self
.
has_inner_state
else
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/model_runner.py
View file @
cf5f000d
...
@@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
_update_inputs_to_capture_for_enc_dec_model
(
self
.
_update_inputs_to_capture_for_enc_dec_model
(
capture_inputs
)
capture_inputs
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
graph_runner
.
capture
(
**
capture_inputs
)
graph_runner
.
capture
(
**
capture_inputs
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
...
@@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
not
bypass_model_exec
:
if
not
bypass_model_exec
:
with
set_forward_context
(
model_input
.
attn_metadata
,
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/pooling_model_runner.py
View file @
cf5f000d
...
@@ -105,7 +105,8 @@ class PoolingModelRunner(
...
@@ -105,7 +105,8 @@ class PoolingModelRunner(
if
model_input
.
token_types
is
not
None
:
if
model_input
.
token_types
is
not
None
:
cross_enc_kwargs
[
"token_type_ids"
]
=
model_input
.
token_types
cross_enc_kwargs
[
"token_type_ids"
]
=
model_input
.
token_types
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/worker.py
View file @
cf5f000d
...
@@ -21,7 +21,7 @@ from vllm.platforms import current_platform
...
@@ -21,7 +21,7 @@ from vllm.platforms import current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
from
vllm.utils
import
GiB_bytes
,
memory_profiling
from
vllm.utils
import
GiB_bytes
,
bind_kv_cache
,
memory_profiling
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
...
@@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
self
.
cache_engine
[
ve
].
gpu_cache
self
.
cache_engine
[
ve
].
gpu_cache
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
self
.
gpu_cache
)
def
_warm_up_model
(
self
)
->
None
:
def
_warm_up_model
(
self
)
->
None
:
if
not
self
.
model_config
.
enforce_eager
:
if
not
self
.
model_config
.
enforce_eager
:
...
...
vllm/worker/worker_base.py
View file @
cf5f000d
...
@@ -43,6 +43,7 @@ class WorkerBase(ABC):
...
@@ -43,6 +43,7 @@ class WorkerBase(ABC):
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
compilation_config
=
vllm_config
.
compilation_config
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
self
.
current_platform
=
current_platform
self
.
current_platform
=
current_platform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment