Commit bb59a3e7 authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

Fix cache engine

parent 5a309bb5
from typing import List, Tuple from typing import Dict, List, Tuple
import torch import torch
...@@ -14,34 +14,30 @@ class CacheEngine: ...@@ -14,34 +14,30 @@ class CacheEngine:
num_layers: int, num_layers: int,
num_heads: int, num_heads: int,
head_size: int, head_size: int,
block_size: int,
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
block_size: int, dtype: torch.dtype,
dtype: torch.dtype = torch.float16,
) -> None: ) -> None:
self.worker_id = worker_id self.worker_id = worker_id
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.num_layers = num_layers self.num_layers = num_layers
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks self.num_cpu_blocks = num_cpu_blocks
self.block_size = block_size
self.dtype = dtype self.dtype = dtype
# Initialize the cache. # Initialize the cache.
self.gpu_cache = self.allocate_gpu_cache() self.gpu_cache = self.allocate_gpu_cache()
self.cpu_cache = self.allocate_cpu_cache() self.cpu_cache = self.allocate_cpu_cache()
# Initialize the streams. # Initialize the stream for caching operations.
self.copy_stream = torch.cuda.Stream(device=gpu_id) self.cache_stream = torch.cuda.Stream(device=gpu_id)
self.swap_stream = torch.cuda.Stream(device=gpu_id) assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
assert self.copy_stream != self.swap_stream # Initialize the events for stream synchronization.
current_stream = torch.cuda.current_stream(device=gpu_id) self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
assert self.copy_stream != current_stream
assert self.swap_stream != current_stream
# Initialize the events for synchronization.
def allocate_gpu_cache(self) -> List[List[KVCache]]: def allocate_gpu_cache(self) -> List[List[KVCache]]:
gpu_cache: List[List[KVCache]] = [] gpu_cache: List[List[KVCache]] = []
...@@ -81,29 +77,14 @@ class CacheEngine: ...@@ -81,29 +77,14 @@ class CacheEngine:
cpu_cache.append(layer_cache) cpu_cache.append(layer_cache)
return cpu_cache return cpu_cache
def copy(self, src_to_dst: Dict[int, int]) -> None:
    """Copy cache blocks inside the GPU cache.

    Args:
        src_to_dst: Mapping from source block number to destination
            block number.

    NOTE(review): reconstructed from a fused side-by-side diff. The
    post-commit body is a stub: it iterates the per-layer CUDA events
    (`self.events`, one per layer) without issuing any copy yet —
    confirm against the repository before relying on this.
    """
    for event in self.events:
        # TODO: launch the per-layer COPY op and record/wait on `event`.
        pass
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
    """Swap cache blocks from the CPU cache into the GPU cache.

    Args:
        src_to_dst: Mapping from source (CPU) block number to
            destination (GPU) block number.

    NOTE(review): reconstructed from a fused side-by-side diff. The
    post-commit body is a stub that iterates the per-layer CUDA events
    without performing the transfer — confirm against the repository.
    """
    for event in self.events:
        # TODO: launch the per-layer SWAP_IN op on the cache stream and
        # record/wait on `event`.
        pass
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
    """Swap cache blocks from the GPU cache out to the CPU cache.

    Args:
        src_to_dst: Mapping from source (GPU) block number to
            destination (CPU) block number.

    NOTE(review): reconstructed from a fused side-by-side diff. The
    post-commit body is a stub that iterates the per-layer CUDA events
    without performing the transfer — confirm against the repository.
    """
    for event in self.events:
        # TODO: launch the per-layer SWAP_OUT op on the cache stream and
        # record/wait on `event`.
        pass
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment