Unverified Commit 4ae0969c authored by Lianmin Zheng, committed by GitHub

Move status check in the memory pool to CPU (#1557)

parent 317631ca
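
In short, this change drops the CUDA-resident boolean `mem_state` mask and the `prefetch_buffer` from `BaseTokenToKVPool` and instead tracks free KV-cache slots in a NumPy array on the host, so `available_size()`, `alloc()`, and `free()` no longer need to query GPU state; only the selected indices are materialized on the device. A minimal, self-contained sketch of the new bookkeeping follows (the class name `FreeSlotPool` is hypothetical and not part of the patch; it requires a CUDA device, as the patched code does):

```python
import numpy as np
import torch


class FreeSlotPool:
    """Host-side free-list bookkeeping, mirroring the patched BaseTokenToKVPool."""

    def __init__(self, size: int):
        self.size = size
        self.clear()

    def available_size(self) -> int:
        # Pure CPU check; no GPU synchronization required.
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Only the chosen slot indices are copied to the GPU.
        return torch.tensor(select_index, dtype=torch.int32, device="cuda")

    def free(self, free_index: torch.Tensor):
        # Returned indices come back from the GPU and rejoin the host-side list.
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

    def clear(self):
        # Slot 0 is reserved for dummy writes from padded tokens.
        self.free_slots = np.arange(1, self.size + 1)
```

Compared with the previous `torch.nonzero` scan over a GPU boolean mask, slicing a NumPy array keeps the availability check entirely on the CPU.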
@@ -19,6 +19,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -69,56 +70,27 @@ class BaseTokenToKVPool(ABC):
         else:
             self.store_dtype = dtype
 
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device="cuda")
 
     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
     @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
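
For illustration, the alloc/free round trip under the new free-list scheme looks roughly like this (using the hypothetical `FreeSlotPool` sketch above; the numbers are just an example and a CUDA device is assumed):

```python
pool = FreeSlotPool(size=8)    # free slots are [1..8]; slot 0 stays reserved
idx = pool.alloc(3)            # tensor([1, 2, 3], device='cuda:0', dtype=torch.int32)
print(pool.available_size())   # 5
pool.free(idx)                 # slots 1-3 rejoin the host-side free list
print(pool.available_size())   # 8
```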
@@ -152,19 +124,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
            )
             for _ in range(layer_num)
         ]
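
A hedged usage sketch of the new `device` parameter on the MHA pool follows; the import path and the order of the leading `size`/`dtype` arguments are assumptions based on the surrounding context, not shown in this diff:

```python
import torch

# Assumed import path; adjust to wherever MHATokenToKVPool lives in this repo.
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool

pool = MHATokenToKVPool(
    size=4096,
    dtype=torch.float16,
    head_num=8,
    head_dim=128,
    layer_num=32,
    device="cuda",  # previously hard-coded inside the pool; now passed by the caller
)

idx = pool.alloc(16)  # CUDA int32 indices into the per-layer KV buffers
pool.free(idx)
```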
@@ -210,15 +188,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -409,8 +409,11 @@ class ModelRunner:
             4096,
         )
 
+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -422,6 +425,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -430,6 +434,7 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         logger.info(
             f"Memory pool end. "