Unverified Commit 15521495 authored by wangxiyu191, committed by GitHub

refactor: Extract repeated member variables in KVCache subclasses to base class. (#6323)

parent ebe58d54
@@ -94,6 +94,33 @@ class ReqToTokenPool:
 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
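The hunk above is the core of the change: the shared fields (size, page_size, dtype, store_dtype, layer_num, start_layer, end_layer, and the memory-saver adapter) now live in one place, and each subclass delegates to the base constructor. A minimal standalone sketch of the pattern, with abbreviated hypothetical names rather than the project's real classes:

```python
import abc


class BasePool(abc.ABC):
    # An @abstractmethod __init__ may still carry a body: the base class
    # stays non-instantiable, but subclasses can reuse it via super().
    @abc.abstractmethod
    def __init__(self, size: int, layer_num: int):
        self.size = size
        self.layer_num = layer_num


class MHAPool(BasePool):
    def __init__(self, size: int, layer_num: int, head_num: int):
        super().__init__(size, layer_num)  # shared fields, set once in the base
        self.head_num = head_num           # subclass-specific field


pool = MHAPool(size=1024, layer_num=32, head_num=8)
assert (pool.size, pool.head_num) == (1024, 8)
```

Keeping `@abc.abstractmethod` on `__init__` preserves the old behavior that `KVCache` itself cannot be instantiated, while the body is still reachable through `super().__init__()`.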
@@ -217,25 +244,20 @@ class MHATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
         self.head_num = head_num
         self.head_dim = head_dim
-        self.layer_num = layer_num
         self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1

         self.layer_transfer_counter = None
         self.capture_mode = False
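The deleted block repeats the fp8 workaround that now lives in the base class: float8 KV buffers are stored as torch.uint8 because, as the NOTE says, Tensor.index_put is not implemented for torch.float8_e5m2. A rough sketch of how that store_dtype indirection plays out at write and read time (hypothetical shapes and values, not the pool's actual layout):

```python
import torch

dtype = torch.float8_e5m2
store_dtype = torch.uint8  # byte-compatible carrier for the fp8 payload

buf = torch.zeros(16, 8, dtype=store_dtype)   # KV storage held as raw bytes
loc = torch.tensor([2, 5, 7])                 # token slots being written
cache_k = torch.randn(3, 8).to(dtype)         # incoming fp8 key vectors

buf[loc] = cache_k.view(store_dtype)  # scattered write works on uint8
keys = buf.view(dtype)                # reinterpret the bytes as fp8 on read
```

Because float8_e5m2 and uint8 are both one byte wide, `.view()` reinterprets the storage in place without copying.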
@@ -493,26 +515,21 @@ class MLATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
                 torch.zeros(
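A side effect of the refactor is visible in the `with` line: instead of building a throwaway local TorchMemorySaverAdapter, the constructor reuses the self.memory_saver_adapter created in the base class. The create(enable=...) / region() shape is a small factory-plus-context-manager pattern; a self-contained toy version (stand-in classes, not the real adapter) might look like:

```python
from contextlib import nullcontext


class _NoopSaver:
    """Disabled path: region() is a do-nothing context manager."""
    def region(self):
        return nullcontext()


class _TrackingSaver:
    """Toy enabled path: just counts how many regions were opened."""
    def __init__(self):
        self.regions_opened = 0

    def region(self):
        self.regions_opened += 1
        return nullcontext()


def create_saver(enable: bool):
    # Mirrors the create(enable=...) factory: callers always call region(),
    # and configuration decides whether it does real work.
    return _TrackingSaver() if enable else _NoopSaver()


saver = create_saver(enable=True)
with saver.region():
    kv_buffer = [bytearray(4096) for _ in range(8)]  # allocations the saver would manage
```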
@@ -636,20 +653,18 @@ class DoubleSparseTokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -672,9 +687,6 @@ class DoubleSparseTokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
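The `layer_id - self.start_layer` arithmetic in the surviving context line shows why start_layer belongs in the base class: a pool may hold only a contiguous slice of the model's layers (for example under pipeline parallelism), so global layer ids must be translated to local list indices. A toy illustration with made-up numbers:

```python
# Hypothetical pool owning global layers 16..31 of a 32-layer model.
start_layer, layer_num = 16, 16
k_buffer = [f"K buffer for global layer {start_layer + i}" for i in range(layer_num)]


def get_key_buffer(layer_id: int) -> str:
    return k_buffer[layer_id - start_layer]  # global id -> local index


assert get_key_buffer(20) == "K buffer for global layer 20"
```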
@@ -742,7 +754,7 @@ class HostKVCache(abc.ABC):
     def __init__(
         self,
-        device_pool: MHATokenToKVPool,
+        device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
         pin_memory: bool,
@@ -914,6 +926,8 @@ class HostKVCache(abc.ABC):
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
@@ -997,6 +1011,8 @@ class MHATokenToKVPoolHost(HostKVCache):
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
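The last three hunks are a typing cleanup enabled by the refactor: HostKVCache.__init__ now accepts any KVCache, while each concrete host pool narrows the device_pool attribute with a class-level annotation so static type checkers still see the precise pool type. A minimal sketch of the same widen-then-narrow pattern, with hypothetical class names:

```python
class DevicePool:                    # plays the role of KVCache
    pass


class MHADevicePool(DevicePool):     # plays the role of MHATokenToKVPool
    head_num = 8


class HostCache:
    def __init__(self, device_pool: DevicePool):  # widened: any pool is accepted
        self.device_pool = device_pool


class MHAHostCache(HostCache):
    # Narrowing annotation: no runtime effect, but mypy/pyright now treat
    # self.device_pool as an MHADevicePool inside this class.
    device_pool: MHADevicePool

    def head_count(self) -> int:
        return self.device_pool.head_num
```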