"docs/source/en/api/pipelines/stable_diffusion/overview.mdx" did not exist on "ac3738462b1732193908b0fb7e557bedac3c57a5"
Commit bb59a3e7 authored by Woosuk Kwon

Fix cache engine

parent 5a309bb5
-from typing import List, Tuple
+from typing import Dict, List, Tuple

 import torch

@@ -14,34 +14,30 @@ class CacheEngine:
         num_layers: int,
         num_heads: int,
         head_size: int,
+        block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
-        block_size: int,
-        dtype: torch.dtype = torch.float16,
+        dtype: torch.dtype,
     ) -> None:
         self.worker_id = worker_id
         self.gpu_id = gpu_id
         self.num_layers = num_layers
         self.num_heads = num_heads
         self.head_size = head_size
+        self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
         self.num_cpu_blocks = num_cpu_blocks
-        self.block_size = block_size
         self.dtype = dtype
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
         self.cpu_cache = self.allocate_cpu_cache()

-        # Initialize the streams.
-        self.copy_stream = torch.cuda.Stream(device=gpu_id)
-        self.swap_stream = torch.cuda.Stream(device=gpu_id)
-        assert self.copy_stream != self.swap_stream
-        current_stream = torch.cuda.current_stream(device=gpu_id)
-        assert self.copy_stream != current_stream
-        assert self.swap_stream != current_stream
-        # Initialize the events for synchronization.
+        # Initialize the stream for caching operations.
+        self.cache_stream = torch.cuda.Stream(device=gpu_id)
+        assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
+        # Initialize the events for stream synchronization.
         self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
     def allocate_gpu_cache(self) -> List[List[KVCache]]:
         gpu_cache: List[List[KVCache]] = []
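The hunk above replaces the two dedicated copy/swap streams with a single cache_stream plus one CUDA event per layer. The usual intent of this pattern is overlap: issue cache transfers on the side stream, record a per-layer event, and have the compute stream wait on each layer's event only right before that layer reads its KV blocks. A minimal sketch of that pattern, not code from this commit; issue_cache_op, run_attention, and num_layers are hypothetical placeholders:

import torch

num_layers = 4  # example value

def issue_cache_op(layer: int) -> None:
    # Hypothetical placeholder for a per-layer block copy/swap launch.
    pass

def run_attention(layer: int) -> None:
    # Hypothetical placeholder for compute that reads this layer's KV cache.
    pass

cache_stream = torch.cuda.Stream()
events = [torch.cuda.Event() for _ in range(num_layers)]

# Issue all cache operations on the side stream, one event per layer.
with torch.cuda.stream(cache_stream):
    for layer in range(num_layers):
        issue_cache_op(layer)
        events[layer].record(cache_stream)

# On the compute (current) stream, wait per layer rather than blocking on
# all transfers at once, so early layers can start while later ones copy.
for layer in range(num_layers):
    torch.cuda.current_stream().wait_event(events[layer])
    run_attention(layer)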
@@ -81,29 +77,14 @@ class CacheEngine:
         cpu_cache.append(layer_cache)
         return cpu_cache

-    def copy(
-        self,
-        src_block_numbers: List[int],
-        dst_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the COPY op.
+    def copy(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass

-    def swap_out(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_OUT op on the swap stream.
+    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass

-    def swap_in(
-        self,
-        gpu_block_numbers: List[int],
-        cpu_block_numbers: List[int],
-    ) -> None:
-        for layer in range(self.num_layers):
-            # TODO: Call the SWAP_IN op on the swap stream.
+    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+        for event in self.events:
             pass
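The copy/swap bodies remain stubs after this change; what the commit fixes is the interface, which now takes a src_to_dst mapping from source block number to destination block number instead of two parallel lists. As an illustration only, a hedged sketch of how swap_out could later be filled in under that interface, assuming each layer's cache is a (key_blocks, value_blocks) tensor pair indexed by block number along dim 0 (swap_out_sketch and its arguments are hypothetical, not this commit's implementation):

from typing import Dict, List, Tuple

import torch

KVCache = Tuple[torch.Tensor, torch.Tensor]  # (key_blocks, value_blocks)

def swap_out_sketch(
    gpu_cache: List[KVCache],
    cpu_cache: List[KVCache],
    src_to_dst: Dict[int, int],
    cache_stream: torch.cuda.Stream,
    events: List[torch.cuda.Event],
) -> None:
    # Issue all GPU -> CPU block copies on the cache stream.
    with torch.cuda.stream(cache_stream):
        for layer, (gpu_kv, cpu_kv) in enumerate(zip(gpu_cache, cpu_cache)):
            for src, dst in src_to_dst.items():
                # Copy one GPU block to its CPU slot; non_blocking only
                # overlaps with compute if the CPU tensors are pinned.
                cpu_kv[0][dst].copy_(gpu_kv[0][src], non_blocking=True)
                cpu_kv[1][dst].copy_(gpu_kv[1][src], non_blocking=True)
            # Mark this layer's transfers complete on the cache stream.
            events[layer].record(cache_stream)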