Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5ee14494
Unverified
Commit
5ee14494
authored
Mar 20, 2024
by
Woosuk Kwon
Committed by
GitHub
Mar 20, 2024
Browse files
[Misc] Remove cache stream and cache events (#3461)
parent
4ad521d8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
86 additions
and
32 deletions
+86
-32
tests/worker/test_swap.py
tests/worker/test_swap.py
+77
-0
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+8
-18
vllm/worker/worker.py
vllm/worker/worker.py
+1
-14
No files found.
tests/worker/test_swap.py
0 → 100644
View file @
5ee14494
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.worker.worker
import
Worker
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
def
test_swap
()
->
None
:
# Configure the engine.
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
,
dtype
=
"half"
,
load_format
=
"dummy"
)
(
model_config
,
cache_config
,
parallel_config
,
scheduler_config
,
device_config
,
_
)
=
engine_args
.
create_engine_configs
()
cache_config
.
num_gpu_blocks
=
100
cache_config
.
num_cpu_blocks
=
100
# Create the worker.
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
worker
=
Worker
(
model_config
=
model_config
,
parallel_config
=
parallel_config
,
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
local_rank
=
0
,
rank
=
0
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
True
,
)
# Initialize the worker.
worker
.
init_model
()
worker
.
load_model
()
worker
.
init_cache_engine
(
cache_config
)
worker
.
warm_up_model
()
# Randomly initialize the cache.
gpu_cache
=
worker
.
cache_engine
.
gpu_cache
cpu_cache
=
worker
.
cache_engine
.
cpu_cache
num_layers
=
len
(
gpu_cache
)
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
.
random_
()
gpu_value_cache
.
random_
()
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
.
random_
()
cpu_value_cache
.
random_
()
allclose
=
lambda
a
,
b
:
torch
.
allclose
(
a
.
cuda
(),
b
.
cuda
(),
rtol
=
0.0
,
atol
=
0.0
)
# Test swap out.
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
{},
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
{})
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_out
.
items
():
assert
allclose
(
gpu_key_cache
[
src
],
cpu_key_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
# Test swap in.
blocks_to_swap_in
=
{
19
:
45
,
67
:
23
,
12
:
78
,
40
:
99
,
1
:
71
}
worker
.
execute_model
(
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
blocks_to_swap_in
,
blocks_to_swap_out
=
{},
blocks_to_copy
=
{})
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_in
.
items
():
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
vllm/worker/cache_engine.py
View file @
5ee14494
...
@@ -38,7 +38,7 @@ class CacheEngine:
...
@@ -38,7 +38,7 @@ class CacheEngine:
self
.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
self
.
num_gpu_blocks
=
cache_config
.
num_gpu_blocks
self
.
num_cpu_blocks
=
cache_config
.
num_cpu_blocks
self
.
num_cpu_blocks
=
cache_config
.
num_cpu_blocks
# Skip initializing
CUDA stream and buffer
for Neuron backend.
# Skip initializing
KV cache
for Neuron backend.
if
is_neuron
():
if
is_neuron
():
return
return
...
@@ -51,12 +51,6 @@ class CacheEngine:
...
@@ -51,12 +51,6 @@ class CacheEngine:
self
.
gpu_cache
=
self
.
allocate_gpu_cache
()
self
.
gpu_cache
=
self
.
allocate_gpu_cache
()
self
.
cpu_cache
=
self
.
allocate_cpu_cache
()
self
.
cpu_cache
=
self
.
allocate_cpu_cache
()
# Initialize the stream for caching operations.
self
.
cache_stream
=
torch
.
cuda
.
Stream
()
assert
self
.
cache_stream
!=
torch
.
cuda
.
current_stream
()
# Initialize the events for stream synchronization.
self
.
events
=
[
torch
.
cuda
.
Event
()
for
_
in
range
(
self
.
num_layers
)]
def
get_key_block_shape
(
self
)
->
Tuple
[
int
,
int
,
int
,
int
]:
def
get_key_block_shape
(
self
)
->
Tuple
[
int
,
int
,
int
,
int
]:
element_size
=
torch
.
tensor
([],
dtype
=
self
.
dtype
).
element_size
()
element_size
=
torch
.
tensor
([],
dtype
=
self
.
dtype
).
element_size
()
x
=
16
//
element_size
x
=
16
//
element_size
...
@@ -126,17 +120,13 @@ class CacheEngine:
...
@@ -126,17 +120,13 @@ class CacheEngine:
)
->
None
:
)
->
None
:
from
vllm._C
import
cache_ops
from
vllm._C
import
cache_ops
with
torch
.
cuda
.
stream
(
self
.
cache_stream
):
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
src_key_cache
,
src_value_cache
=
src
[
i
]
src_key_cache
,
src_value_cache
=
src
[
i
]
dst_key_cache
,
dst_value_cache
=
dst
[
i
]
dst_key_cache
,
dst_value_cache
=
dst
[
i
]
# Copy the key blocks.
# Copy the key blocks.
cache_ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
cache_ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
# Copy the value blocks.
# Copy the value blocks.
cache_ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
cache_ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
event
=
self
.
events
[
i
]
event
.
record
(
stream
=
self
.
cache_stream
)
def
swap_in
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
def
swap_in
(
self
,
src_to_dst
:
Dict
[
int
,
int
])
->
None
:
self
.
_swap
(
self
.
cpu_cache
,
self
.
gpu_cache
,
src_to_dst
)
self
.
_swap
(
self
.
cpu_cache
,
self
.
gpu_cache
,
src_to_dst
)
...
...
vllm/worker/worker.py
View file @
5ee14494
...
@@ -65,7 +65,6 @@ class Worker:
...
@@ -65,7 +65,6 @@ class Worker:
# self.init_cache_engine().
# self.init_cache_engine().
self
.
cache_config
=
None
self
.
cache_config
=
None
self
.
cache_engine
=
None
self
.
cache_engine
=
None
self
.
cache_events
=
None
self
.
gpu_cache
=
None
self
.
gpu_cache
=
None
def
init_model
(
self
,
cupy_port
:
Optional
[
int
]
=
None
)
->
None
:
def
init_model
(
self
,
cupy_port
:
Optional
[
int
]
=
None
)
->
None
:
...
@@ -148,7 +147,6 @@ class Worker:
...
@@ -148,7 +147,6 @@ class Worker:
self
.
cache_config
=
cache_config
self
.
cache_config
=
cache_config
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
cache_engine
=
CacheEngine
(
self
.
cache_config
,
self
.
model_config
,
self
.
parallel_config
)
self
.
parallel_config
)
self
.
cache_events
=
self
.
cache_engine
.
events
self
.
gpu_cache
=
self
.
cache_engine
.
gpu_cache
self
.
gpu_cache
=
self
.
cache_engine
.
gpu_cache
self
.
model_runner
.
set_block_size
(
self
.
cache_engine
.
block_size
)
self
.
model_runner
.
set_block_size
(
self
.
cache_engine
.
block_size
)
...
@@ -166,24 +164,13 @@ class Worker:
...
@@ -166,24 +164,13 @@ class Worker:
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
)
->
None
:
# Issue cache operations.
# Issue cache operations.
issued_cache_op
=
False
# TODO(woosuk): Profile swapping overhead and optimize if needed.
if
blocks_to_swap_in
:
if
blocks_to_swap_in
:
self
.
cache_engine
.
swap_in
(
blocks_to_swap_in
)
self
.
cache_engine
.
swap_in
(
blocks_to_swap_in
)
issued_cache_op
=
True
if
blocks_to_swap_out
:
if
blocks_to_swap_out
:
self
.
cache_engine
.
swap_out
(
blocks_to_swap_out
)
self
.
cache_engine
.
swap_out
(
blocks_to_swap_out
)
issued_cache_op
=
True
if
blocks_to_copy
:
if
blocks_to_copy
:
self
.
cache_engine
.
copy
(
blocks_to_copy
)
self
.
cache_engine
.
copy
(
blocks_to_copy
)
issued_cache_op
=
True
cache_events
=
self
.
cache_events
if
issued_cache_op
else
None
# Wait for cache operations to finish.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
if
cache_events
is
not
None
:
for
event
in
cache_events
:
event
.
wait
()
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
execute_model
(
def
execute_model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment