Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf5f000d
Unverified
Commit
cf5f000d
authored
Jan 10, 2025
by
Chen Zhang
Committed by
GitHub
Jan 10, 2025
Browse files
[torch.compile] Hide KV cache behind torch.compile boundary (#11677)
Signed-off-by:
Chen Zhang
<
zhangch99@outlook.com
>
parent
3de2b1ea
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
198 additions
and
44 deletions
+198
-44
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+12
-6
tests/test_utils.py
tests/test_utils.py
+83
-2
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+3
-0
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+3
-0
vllm/attention/layer.py
vllm/attention/layer.py
+17
-12
vllm/config.py
vllm/config.py
+0
-1
vllm/forward_context.py
vllm/forward_context.py
+20
-13
vllm/utils.py
vllm/utils.py
+35
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+5
-1
vllm/worker/cpu_enc_dec_model_runner.py
vllm/worker/cpu_enc_dec_model_runner.py
+2
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+2
-1
vllm/worker/cpu_pooling_model_runner.py
vllm/worker/cpu_pooling_model_runner.py
+2
-1
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+3
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+2
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+3
-2
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+2
-1
vllm/worker/worker.py
vllm/worker/worker.py
+3
-1
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+1
-0
No files found.
tests/kernels/test_encoder_decoder_attn.py
View file @
cf5f000d
...
...
@@ -142,12 +142,18 @@ def _make_test_resources(test_pt: TestPoint, ) -> TestResources:
torch
.
tensor
([],
dtype
=
torch
.
float32
,
device
=
CUDA_DEVICE
))
# Construct KV cache
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
if
test_pt
.
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
):
kv_cache
=
make_kv_cache
(
test_pt
.
num_blocks
,
test_pt
.
num_heads
,
test_pt
.
head_size
,
test_pt
.
block_size
,
device
=
CUDA_DEVICE
,
backend
=
test_pt
.
backend_name
)
else
:
kv_cache
=
torch
.
tensor
([])
attn
.
kv_cache
=
[
kv_cache
]
return
TestResources
(
scale
,
attn
,
kv_cache
)
...
...
tests/test_utils.py
View file @
cf5f000d
...
...
@@ -7,9 +7,11 @@ import pytest
import
torch
from
vllm_test_utils
import
monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.utils
import
(
FlexibleArgumentParser
,
PlaceholderModule
,
StoreBoolean
,
deprecate_kwargs
,
get_open_port
,
memory_profiling
,
merge_async_iterators
,
supports_kw
)
StoreBoolean
,
bind_kv_cache
,
deprecate_kwargs
,
get_open_port
,
memory_profiling
,
merge_async_iterators
,
supports_kw
)
from
.utils
import
error_on_warning
,
fork_new_process_for_each_test
...
...
@@ -325,6 +327,85 @@ def test_memory_profiling():
lib
.
cudaFree
(
handle2
)
def
test_bind_kv_cache
():
from
vllm.attention
import
Attention
ctx
=
{
'layers.0.self_attn'
:
Attention
(
32
,
128
,
0.1
),
'layers.1.self_attn'
:
Attention
(
32
,
128
,
0.1
),
'layers.2.self_attn'
:
Attention
(
32
,
128
,
0.1
),
'layers.3.self_attn'
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
]
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'layers.1.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
1
]
assert
ctx
[
'layers.2.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
2
]
assert
ctx
[
'layers.3.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
3
]
def
test_bind_kv_cache_non_attention
():
from
vllm.attention
import
Attention
# example from Jamba PP=2
ctx
=
{
'model.layers.20.attn'
:
Attention
(
32
,
128
,
0.1
),
'model.layers.28.attn'
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
torch
.
zeros
((
1
,
)),
]
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'model.layers.20.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'model.layers.28.attn'
].
kv_cache
[
0
]
is
kv_cache
[
1
]
def
test_bind_kv_cache_encoder_decoder
():
from
vllm.attention
import
Attention
,
AttentionType
# example from bart
ctx
=
{
'encoder.layers.0.self_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
ENCODER
),
'decoder.layers.0.encoder_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
ENCODER_DECODER
),
'decoder.layers.0.self_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
DECODER
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
]
encoder_kv_cache
=
ctx
[
'encoder.layers.0.self_attn.attn'
].
kv_cache
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'encoder.layers.0.self_attn.attn'
].
kv_cache
is
encoder_kv_cache
assert
ctx
[
'decoder.layers.0.encoder_attn.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'decoder.layers.0.self_attn.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
def
test_bind_kv_cache_pp
():
cfg
=
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
2
))
with
set_current_vllm_config
(
cfg
):
from
vllm.attention
import
Attention
ctx
=
{
'layers.0.self_attn'
:
Attention
(
32
,
128
,
0.1
),
}
kv_cache
=
[
[
torch
.
zeros
((
1
,
))],
[
torch
.
zeros
((
1
,
))]
]
bind_kv_cache
(
ctx
,
kv_cache
)
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
][
0
]
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
def
test_placeholder_module_error_handling
():
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
...
...
tests/v1/engine/test_engine_core.py
View file @
cf5f000d
...
...
@@ -4,6 +4,7 @@ import uuid
import
pytest
from
transformers
import
AutoTokenizer
from
tests.utils
import
fork_new_process_for_each_test
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
...
...
@@ -36,6 +37,7 @@ def make_request() -> EngineCoreRequest:
)
@
fork_new_process_for_each_test
def
test_engine_core
(
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
...
...
@@ -138,6 +140,7 @@ def test_engine_core(monkeypatch):
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
@
fork_new_process_for_each_test
def
test_engine_core_advanced_sampling
(
monkeypatch
):
"""
A basic end-to-end test to verify that the engine functions correctly
...
...
tests/v1/engine/test_engine_core_client.py
View file @
cf5f000d
...
...
@@ -6,6 +6,7 @@ from typing import Dict, List
import
pytest
from
transformers
import
AutoTokenizer
from
tests.utils
import
fork_new_process_for_each_test
from
vllm
import
SamplingParams
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
...
...
@@ -75,6 +76,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
break
@
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"multiprocessing_mode"
,
[
True
,
False
])
def
test_engine_core_client
(
monkeypatch
,
multiprocessing_mode
:
bool
):
...
...
@@ -143,6 +145,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
client
.
abort_requests
([
request
.
request_id
])
@
fork_new_process_for_each_test
@
pytest
.
mark
.
asyncio
async
def
test_engine_core_client_asyncio
(
monkeypatch
):
...
...
vllm/attention/layer.py
View file @
cf5f000d
...
...
@@ -121,6 +121,13 @@ class Attention(nn.Module):
compilation_config
.
static_forward_context
[
prefix
]
=
self
self
.
layer_name
=
prefix
self
.
attn_type
=
attn_type
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self
.
kv_cache
=
[
torch
.
tensor
([])
for
_
in
range
(
get_current_vllm_config
(
).
parallel_config
.
pipeline_parallel_size
)
]
def
forward
(
self
,
...
...
@@ -148,11 +155,11 @@ class Attention(nn.Module):
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
kv_cache
,
self
.
layer_name
)
query
,
key
,
value
,
output
,
self
.
layer_name
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
kv_cache
,
self
.
layer_name
)
self
.
layer_name
)
def
extra_repr
(
self
)
->
str
:
s
=
f
"head_size=
{
self
.
impl
.
head_size
}
"
# type: ignore
...
...
@@ -230,12 +237,12 @@ def unified_attention(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
self
=
forward_context
.
static_forward_context
[
layer_name
]
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
attn_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
)
...
...
@@ -244,7 +251,6 @@ def unified_attention_fake(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
...
...
@@ -253,7 +259,7 @@ def unified_attention_fake(
direct_register_custom_op
(
op_name
=
"unified_attention"
,
op_func
=
unified_attention
,
mutates_args
=
[
"kv_cache"
],
mutates_args
=
[],
fake_impl
=
unified_attention_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
...
...
@@ -264,12 +270,12 @@ def unified_attention_with_output(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
forward_context
:
ForwardContext
=
get_forward_context
()
attn_metadata
=
forward_context
.
dynamic_forward_context
self
=
forward_context
.
static_forward_context
[
layer_name
]
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
attn_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
query
,
key
,
value
,
...
...
@@ -285,7 +291,6 @@ def unified_attention_with_output_fake(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
None
:
return
...
...
@@ -294,7 +299,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op
(
op_name
=
"unified_attention_with_output"
,
op_func
=
unified_attention_with_output
,
mutates_args
=
[
"kv_cache"
,
"output"
],
mutates_args
=
[
"output"
],
fake_impl
=
unified_attention_with_output_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
vllm/config.py
View file @
cf5f000d
...
...
@@ -2780,7 +2780,6 @@ class CompilationConfig(BaseModel):
compilation_time
:
float
=
PrivateAttr
# Per-model forward context
# Mainly used to store attention cls
# Map from layer name to the attention cls
static_forward_context
:
Dict
[
str
,
Any
]
=
PrivateAttr
...
...
vllm/forward_context.py
View file @
cf5f000d
...
...
@@ -2,7 +2,7 @@ import time
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
import
torch
...
...
@@ -10,6 +10,9 @@ import vllm.envs as envs
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
logger
=
init_logger
(
__name__
)
track_batchsize
:
bool
=
envs
.
VLLM_LOG_BATCHSIZE_INTERVAL
>=
0
...
...
@@ -21,9 +24,12 @@ batchsize_forward_time: defaultdict = defaultdict(list)
@
dataclass
class
ForwardContext
:
static_forward_context
:
Dict
[
str
,
Any
]
# copy from vllm_config.compilation_config.static_forward_context
attn_layers
:
Dict
[
str
,
Any
]
# TODO: extend to support per-layer dynamic forward context
dynamic_forward_context
:
Any
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
...
@@ -38,34 +44,35 @@ def get_forward_context() -> ForwardContext:
@
contextmanager
def
set_forward_context
(
context
:
Any
,
vllm_config
:
VllmConfig
):
def
set_forward_context
(
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
"""
global
forward_start_time
need_to_track_batchsize
=
track_batchsize
and
context
is
not
None
need_to_track_batchsize
=
track_batchsize
and
attn_metadata
is
not
None
if
need_to_track_batchsize
:
forward_start_time
=
time
.
perf_counter
()
global
_forward_context
prev_context
=
_forward_context
_forward_context
=
ForwardContext
(
static_forward_context
=
vllm_config
.
compilation_config
.
static_forward_context
,
dynamic_forward_context
=
context
)
attn_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
)
try
:
yield
finally
:
global
batchsize_counter
global
last_logging_time
,
batchsize_logging_interval
if
need_to_track_batchsize
:
if
hasattr
(
context
,
"num_prefill_tokens"
):
if
hasattr
(
attn_metadata
,
"num_prefill_tokens"
):
# for v0 attention backends
batchsize
=
context
.
num_prefill_tokens
+
\
context
.
num_decode_tokens
batchsize
=
attn_metadata
.
num_prefill_tokens
+
\
attn_metadata
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
context
.
num_input_tokens
batchsize
=
attn_metadata
.
num_input_tokens
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
...
...
vllm/utils.py
View file @
cf5f000d
...
...
@@ -2138,3 +2138,38 @@ def get_mp_context():
_check_multiproc_method
()
mp_method
=
envs
.
VLLM_WORKER_MULTIPROC_METHOD
return
multiprocessing
.
get_context
(
mp_method
)
def
bind_kv_cache
(
ctx
:
Dict
[
str
,
Any
],
kv_cache
:
List
[
List
[
torch
.
Tensor
]],
# [virtual_engine][layer_index]
)
->
None
:
# Bind the kv_cache tensor to Attention modules, similar to
# ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)]
# Special things handled here:
# 1. Some models have non-attention layers, e.g., Jamba
# 2. Pipeline parallelism, each rank only has a subset of layers
# 3. Encoder attention has no kv cache
# 4. Encoder-decoder models, encoder-decoder attention and decoder-only
# attention of the same layer (e.g., bart's decoder.layers.1.self_attn
# and decoder.layers.1.encoder_attn) is mapped to the same kv cache
# tensor
from
vllm.attention
import
AttentionType
from
vllm.model_executor.models.utils
import
extract_layer_index
layer_need_kv_cache
=
[
layer_name
for
layer_name
in
ctx
if
ctx
[
layer_name
].
attn_type
in
(
AttentionType
.
DECODER
,
AttentionType
.
ENCODER_DECODER
)
]
layer_index_sorted
=
sorted
(
set
(
extract_layer_index
(
layer_name
)
for
layer_name
in
layer_need_kv_cache
))
for
layer_name
in
layer_need_kv_cache
:
kv_cache_idx
=
layer_index_sorted
.
index
(
extract_layer_index
(
layer_name
))
forward_ctx
=
ctx
[
layer_name
]
assert
len
(
forward_ctx
.
kv_cache
)
==
len
(
kv_cache
)
for
ve
,
ve_kv_cache
in
enumerate
(
kv_cache
):
assert
forward_ctx
.
kv_cache
[
ve
].
numel
()
==
0
forward_ctx
.
kv_cache
[
ve
]
=
ve_kv_cache
[
kv_cache_idx
]
vllm/v1/worker/gpu_model_runner.py
View file @
cf5f000d
...
...
@@ -16,7 +16,8 @@ from vllm.model_executor.model_loader import get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.sampling_params
import
SamplingType
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
)
LayerBlockType
,
bind_kv_cache
,
cdiv
,
is_pin_memory_available
)
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
from
vllm.v1.engine.mm_input_mapper
import
MMInputMapperClient
...
...
@@ -860,3 +861,6 @@ class GPUModelRunner:
torch
.
zeros
(
kv_cache_shape
,
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
))
bind_kv_cache
(
self
.
vllm_config
.
compilation_config
.
static_forward_context
,
[
self
.
kv_caches
])
vllm/worker/cpu_enc_dec_model_runner.py
View file @
cf5f000d
...
...
@@ -305,7 +305,8 @@ class CPUEncoderDecoderModelRunner(
intermediate_tensors
,
}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Compute the logits.
...
...
vllm/worker/cpu_model_runner.py
View file @
cf5f000d
...
...
@@ -526,7 +526,8 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
execute_model_kwargs
.
update
(
{
"previous_hidden_states"
:
previous_hidden_states
})
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/cpu_pooling_model_runner.py
View file @
cf5f000d
...
...
@@ -69,7 +69,8 @@ class CPUPoolingModelRunner(
intermediate_tensors
,
}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_states
=
model_executable
(
**
execute_model_kwargs
)
# Only perform pooling in the driver worker.
...
...
vllm/worker/cpu_worker.py
View file @
cf5f000d
...
...
@@ -13,7 +13,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
bind_kv_cache
from
vllm.worker.cpu_enc_dec_model_runner
import
CPUEncoderDecoderModelRunner
from
vllm.worker.cpu_model_runner
import
CPUModelRunner
,
CPUModelRunnerBase
from
vllm.worker.cpu_pooling_model_runner
import
CPUPoolingModelRunner
...
...
@@ -293,6 +293,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self
.
cache_engine
[
ve
].
cpu_cache
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
self
.
cpu_cache
)
self
.
model_runner
.
block_size
=
self
.
cache_engine
[
0
].
block_size
assert
all
(
...
...
vllm/worker/enc_dec_model_runner.py
View file @
cf5f000d
...
...
@@ -175,7 +175,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
}
if
self
.
has_inner_state
else
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
model_input
.
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/model_runner.py
View file @
cf5f000d
...
...
@@ -1527,7 +1527,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
_update_inputs_to_capture_for_enc_dec_model
(
capture_inputs
)
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
graph_runner
.
capture
(
**
capture_inputs
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
...
...
@@ -1695,7 +1696,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
not
bypass_model_exec
:
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/pooling_model_runner.py
View file @
cf5f000d
...
...
@@ -105,7 +105,8 @@ class PoolingModelRunner(
if
model_input
.
token_types
is
not
None
:
cross_enc_kwargs
[
"token_type_ids"
]
=
model_input
.
token_types
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
,
virtual_engine
):
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/worker.py
View file @
cf5f000d
...
...
@@ -21,7 +21,7 @@ from vllm.platforms import current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
from
vllm.utils
import
GiB_bytes
,
memory_profiling
from
vllm.utils
import
GiB_bytes
,
bind_kv_cache
,
memory_profiling
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
GPUModelRunnerBase
,
ModelRunner
...
...
@@ -285,6 +285,8 @@ class Worker(LocalOrDistributedWorkerBase):
self
.
cache_engine
[
ve
].
gpu_cache
for
ve
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
bind_kv_cache
(
self
.
compilation_config
.
static_forward_context
,
self
.
gpu_cache
)
def
_warm_up_model
(
self
)
->
None
:
if
not
self
.
model_config
.
enforce_eager
:
...
...
vllm/worker/worker_base.py
View file @
cf5f000d
...
...
@@ -43,6 +43,7 @@ class WorkerBase(ABC):
self
.
prompt_adapter_config
=
vllm_config
.
prompt_adapter_config
self
.
observability_config
=
vllm_config
.
observability_config
self
.
kv_transfer_config
=
vllm_config
.
kv_transfer_config
self
.
compilation_config
=
vllm_config
.
compilation_config
from
vllm.platforms
import
current_platform
self
.
current_platform
=
current_platform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment