Project: norm / vllm
Commit: 6f058c7b (parent: a1c67e6d)
Author: Woosuk Kwon
Date: Feb 16, 2023

    Implement cache ops
Showing 4 changed files with 109 additions and 6 deletions:
    cacheflow/worker/cache_engine.py   +23  -6
    csrc/cache.cpp                     +20  -0
    csrc/cache_kernel.cu               +43  -0
    setup.py                           +23  -0
cacheflow/worker/cache_engine.py  (+23 -6)

 from typing import Dict, List, Tuple

 import torch

+from cacheflow import ops

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -92,14 +93,30 @@ class CacheEngine:
         cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache

+    def _copy_blocks(
+        self,
+        src: List[KVCache],
+        dst: List[KVCache],
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        with torch.cuda.stream(self.cache_stream):
+            for i in range(self.num_layers):
+                src_key_cache, src_value_cache = src[i]
+                dst_key_cache, dst_value_cache = dst[i]
+                # Copy the key blocks.
+                ops.copy_cache_blocks(
+                    src_key_cache, dst_key_cache, src_to_dst)
+                # Copy the value blocks.
+                ops.copy_cache_blocks(
+                    src_value_cache, dst_value_cache, src_to_dst)
+                event = self.events[i]
+                event.record(stream=self.cache_stream)
+
     def copy(self, src_to_dst: Dict[int, int]) -> None:
-        for event in self.events:
-            pass
+        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)

     def swap_in(self, src_to_dst: Dict[int, int]) -> None:
-        for event in self.events:
-            pass
+        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)

     def swap_out(self, src_to_dst: Dict[int, int]) -> None:
-        for event in self.events:
-            pass
+        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
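For orientation, here is a minimal sketch (an illustration, not part of this commit) of how the new swap path might be driven and synchronized by a caller. The swap_in method and the per-layer events come from the diff above; the helper name and the wait loop are hypothetical.

# Hypothetical usage sketch: enqueue a CPU -> GPU swap on the cache stream,
# then block until every per-layer copy event has completed.
from typing import Dict

def swap_in_and_wait(cache_engine, src_to_dst: Dict[int, int]) -> None:
    cache_engine.swap_in(src_to_dst)    # enqueues block copies on cache_stream
    for event in cache_engine.events:   # one event is recorded per layer
        event.synchronize()             # wait until that layer's blocks are copied

# Example block mapping: CPU block 0 -> GPU block 3, CPU block 7 -> GPU block 1.
# swap_in_and_wait(engine, {0: 3, 7: 1})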
csrc/cache.cpp  (new file, +20 -0)

#include <torch/extension.h>

void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void copy_cache_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  copy_blocks(src, dst, block_mapping);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "copy_cache_blocks",
    &copy_cache_blocks,
    "Copy the cache blocks from src to dst");
}
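Once the extension is built (see setup.py below), this binding is importable from Python as cacheflow.ops.copy_cache_blocks; pybind11 converts a Python dict of ints into the std::map<int64_t, int64_t> parameter. A small sketch under those assumptions (the tensor shapes are arbitrary; a "block" is one slice along the first dimension):

# Sketch: copy blocks 0 and 1 of src into slots 2 and 3 of dst on the same GPU.
import torch
from cacheflow import ops

src = torch.arange(4 * 8, dtype=torch.float32, device='cuda').view(4, 8)
dst = torch.zeros_like(src)

ops.copy_cache_blocks(src, dst, {0: 2, 1: 3})   # dict -> std::map via pybind11

assert torch.equal(dst[2], src[0])
assert torch.equal(dst[3], src[1])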
csrc/cache_kernel.cu  (new file, +43 -0)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cassert>
#include <map>

void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    assert(src_device.index() == dst_device.index());
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    assert(false);
  }

  void *src_ptr = src.data_ptr();
  void *dst_ptr = dst.data_ptr();

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
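For readers who prefer Python, the kernel above amounts to copying whole blocks (slices along dim 0) between tensors according to the mapping, with cudaMemcpyAsync choosing the copy direction from the devices involved. An illustrative pure-PyTorch equivalent (a sketch, not code from this commit):

from typing import Dict
import torch

def copy_blocks_reference(src: torch.Tensor, dst: torch.Tensor,
                          block_mapping: Dict[int, int]) -> None:
    # Each entry maps a source block index to a destination block index.
    for src_block, dst_block in block_mapping.items():
        dst[dst_block].copy_(src[src_block], non_blocking=True)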
setup.py  (+23 -0)

import setuptools
from torch.utils import cpp_extension

CXX_FLAGS = ['-g']
NVCC_FLAGS = ['-O2']

ext_modules = []

# Cache operations.
cache_extension = cpp_extension.CUDAExtension(
    name='cacheflow.ops',
    sources=['csrc/cache.cpp', 'csrc/cache_kernel.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(cache_extension)

setuptools.setup(
    name='cacheflow',
    requires_python='>=3.9',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
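Building the package (for example with "pip install -e .", an assumed workflow) compiles csrc/cache.cpp and csrc/cache_kernel.cu into the cacheflow.ops extension that cache_engine.py imports. A quick post-build sanity check might look like:

# Post-build sanity check (illustrative, not part of this commit).
from cacheflow import ops

assert hasattr(ops, 'copy_cache_blocks')
print(ops.copy_cache_blocks.__doc__)  # includes the docstring set in csrc/cache.cpp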