Commit 9e68a682 authored by Woosuk Kwon
Browse files

Fix return type error

parent 8edcabc7
...@@ -42,7 +42,7 @@ class CacheEngine: ...@@ -42,7 +42,7 @@ class CacheEngine:
# Initialize the events for stream synchronization. # Initialize the events for stream synchronization.
self.events = [torch.cuda.Event() for _ in range(self.num_layers)] self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
def get_key_block_shape(self) -> Tuple[int, int, int, int, int]: def get_key_block_shape(self) -> Tuple[int, int, int, int]:
element_size = torch.tensor([], dtype=self.dtype).element_size() element_size = torch.tensor([], dtype=self.dtype).element_size()
x = 16 // element_size x = 16 // element_size
return ( return (
...@@ -52,7 +52,7 @@ class CacheEngine: ...@@ -52,7 +52,7 @@ class CacheEngine:
x, x,
) )
def get_value_block_shape(self) -> Tuple[int, int, int, int]: def get_value_block_shape(self) -> Tuple[int, int, int]:
return ( return (
self.num_heads, self.num_heads,
self.block_size, self.block_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment