Unverified Commit 71d12195 authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[Kernel] correct cpu worker function parameter type (#19745)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent e384f2f1
......@@ -29,7 +29,7 @@ class _PagedAttention:
head_size: int,
*args,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
return 2, num_blocks, block_size * num_kv_heads * head_size
@staticmethod
def split_kv_cache(
......
......@@ -3,7 +3,7 @@
"""A CPU worker class."""
import os
from importlib import util
from typing import Dict, List, Optional, Set, Tuple, Type
from typing import List, Optional, Set, Tuple, Type
import torch
import torch.distributed
......@@ -88,13 +88,13 @@ class CPUCacheEngine:
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
return kv_cache
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
def swap_in(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
def swap_out(self, src_to_dst: torch.Tensor) -> None:
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment