Commit 6f058c7b authored by Woosuk Kwon

Implement cache ops

parent a1c67e6d
from typing import Dict, List, Tuple

import torch

from cacheflow import ops

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -92,14 +93,30 @@ class CacheEngine:

            cpu_cache.append((key_blocks, value_blocks))
        return cpu_cache
    def _copy_blocks(
        self,
        src: List[KVCache],
        dst: List[KVCache],
        src_to_dst: Dict[int, int],
    ) -> None:
        with torch.cuda.stream(self.cache_stream):
            for i in range(self.num_layers):
                src_key_cache, src_value_cache = src[i]
                dst_key_cache, dst_value_cache = dst[i]
                # Copy the key blocks.
                ops.copy_cache_blocks(
                    src_key_cache, dst_key_cache, src_to_dst)
                # Copy the value blocks.
                ops.copy_cache_blocks(
                    src_value_cache, dst_value_cache, src_to_dst)
                event = self.events[i]
                event.record(stream=self.cache_stream)
    def copy(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.gpu_cache, self.gpu_cache, src_to_dst)

    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.cpu_cache, self.gpu_cache, src_to_dst)

    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
        self._copy_blocks(self.gpu_cache, self.cpu_cache, src_to_dst)
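
The copies are enqueued on cache_stream and a per-layer event is recorded, so swapping can overlap with compute: a reader only needs to wait on the event for the layer it is about to touch. The commit itself only records the events; the waiting pattern below is a hypothetical sketch, with engine standing in for a constructed CacheEngine.

import torch

# Hypothetical driver, assuming engine is an initialized CacheEngine
# with cpu_cache, gpu_cache, cache_stream, and per-layer events.
src_to_dst = {0: 7, 1: 3}   # source block number -> destination block number
engine.swap_in(src_to_dst)  # copies enqueued on engine.cache_stream

# Before reading the swapped-in blocks, make the compute stream wait
# on each layer's copy event (no host synchronization needed).
for event in engine.events:
    event.wait(torch.cuda.current_stream())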
csrc/cache.cpp

#include <torch/extension.h>

#include <map>

// Defined in csrc/cache_kernel.cu.
void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);

void copy_cache_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  copy_blocks(src, dst, block_mapping);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "copy_cache_blocks",
    &copy_cache_blocks,
    "Copy the cache blocks from src to dst");
}
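
Since torch/extension.h pulls in pybind11's STL casters, the std::map argument can be passed a plain Python dict of block numbers. A minimal smoke test, assuming a CUDA device (a CPU-to-CPU mapping would hit the assert in copy_blocks); the tensor shapes are illustrative only:

import torch
from cacheflow import ops

# Two blocks of 16 elements each.
src = torch.randn(2, 16, device='cuda')
dst = torch.zeros(2, 16, device='cuda')

# Copy block 0 of src into block 1 of dst; the dict is converted to a
# std::map<int64_t, int64_t> at the pybind11 boundary.
ops.copy_cache_blocks(src, dst, {0: 1})
torch.cuda.synchronize()
assert torch.equal(dst[1], src[0])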
csrc/cache_kernel.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <cassert>
#include <map>

void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  // Pick the memcpy direction from the device placement of the two
  // caches; same-device GPU copies must stay on a single device.
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    assert(src_device.index() == dst_device.index());
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    assert(false);
  }

  // Use char* so the byte-offset arithmetic below is valid standard
  // C++ (pointer arithmetic on void* is only a GNU extension).
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  // src[0] is a single cache block, so this is the copy unit in bytes.
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    int64_t src_block_number = pair.first;
    int64_t dst_block_number = pair.second;
    int64_t src_offset = src_block_number * block_size_in_bytes;
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
    cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
  }
}
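
The copy unit is one block: indexing src[0] drops the leading block dimension, so element_size() * src[0].numel() is the byte size of a single block, and block numbers scale directly into byte offsets. A quick check of that arithmetic with an illustrative cache shape:

import torch

# Hypothetical per-layer key cache: 128 blocks, each holding
# 16 tokens x 12 heads x 64 head dims in fp16.
key_cache = torch.empty(128, 16, 12, 64, dtype=torch.float16)

# Mirrors the kernel: bytes per block = element size * src[0].numel().
block_size_in_bytes = key_cache.element_size() * key_cache[0].numel()
assert block_size_in_bytes == 2 * 16 * 12 * 64  # 24576 bytes

# Byte offset of block 5, as computed inside the memcpy loop.
src_offset = 5 * block_size_in_bytes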
setup.py

import setuptools
from torch.utils import cpp_extension

CXX_FLAGS = ['-g']
NVCC_FLAGS = ['-O2']

ext_modules = []

# Cache operations.
cache_extension = cpp_extension.CUDAExtension(
    name='cacheflow.ops',
    sources=['csrc/cache.cpp', 'csrc/cache_kernel.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(cache_extension)

setuptools.setup(
    name='cacheflow',
    python_requires='>=3.9',
    ext_modules=ext_modules,
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
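
With BuildExtension as the build_ext command, the standard setuptools workflow compiles both csrc files into the cacheflow.ops module. A hypothetical end-to-end check of the swap-out (device-to-host) path after building; shapes and values are illustrative:

# Build first, e.g.:  pip install -e .
# (or: python setup.py build_ext --inplace)
import torch
from cacheflow import ops

gpu_cache = torch.randn(4, 16, device='cuda')  # 4 blocks of 16 elements
cpu_cache = torch.zeros(4, 16)

# Copy GPU block 0 into CPU block 2 (cudaMemcpyDeviceToHost).
ops.copy_cache_blocks(gpu_cache, cpu_cache, {0: 2})
torch.cuda.synchronize()
assert torch.equal(cpu_cache[2], gpu_cache[0].cpu())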