Unverified commit 10143e1a authored by Liangsheng Yin, committed by GitHub

Memorypool chunked prefetch (#614)

parent 65c65776
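In brief, as far as can be read from the hunks below: this commit deletes the alloc_contiguous fast path used during decode, along with the out_cache_cont_start / out_cache_cont_end plumbing that carried it through Batch, InputMetadata, CudaGraphRunner, ModelRunner, and RadixAttention. In its place, TokenToKVPool.alloc keeps a prefetch buffer of free slot indices and refills it in chunks of at least prefetch_chunk_size (256) slots, so the torch.nonzero(self.mem_state == 0) scan for free slots runs once per chunk instead of once per allocation. A runnable sketch of the allocator is given after the memory-pool hunks at the end of this diff.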
@@ -141,12 +141,5 @@ class RadixAttention(nn.Module):
         if input_metadata.out_cache_loc is not None:
             key_buffer[input_metadata.out_cache_loc] = cache_k
             value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
         else:
             raise RuntimeError()
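With the contiguous branch removed, the only remaining write path in RadixAttention is plain integer-index assignment into the per-layer buffers. A toy snippet, with made-up shapes and names, illustrating that scatter-style write:

import torch

# Toy illustration of the remaining write path: rows of new K/V tensors
# are scattered into arbitrary slots of a preallocated buffer.
key_buffer = torch.zeros(8, 4)            # stand-in for one KV-cache layer buffer
out_cache_loc = torch.tensor([5, 1, 6])   # slot indices returned by alloc()
cache_k = torch.ones(3, 4)                # new key vectors for 3 tokens

key_buffer[out_cache_loc] = cache_k       # index assignment writes one row per token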
@@ -104,8 +104,6 @@ class CudaGraphRunner:
             prefix_lens=None,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=None,
-            out_cache_cont_end=None,
             return_logprob=False,
             top_logprobs_nums=0,
             skip_flashinfer_init=True,
...
@@ -275,8 +275,6 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None

     # For processing logprobs
     return_logprob: bool = False
@@ -566,8 +564,6 @@ class Batch:
         # Alloc mem
         bs = len(self.reqs)
-        alloc_res = self.token_to_kv_pool.alloc_contiguous(bs)
-        if alloc_res is None:
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-
-            if self.out_cache_loc is None:
+        self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
+
+        if self.out_cache_loc is None:
@@ -575,13 +571,6 @@ class Batch:
             self.tree_cache.pretty_print()
             exit()

-            self.out_cache_cont_start = None
-            self.out_cache_cont_end = None
-        else:
-            self.out_cache_loc = alloc_res[0]
-            self.out_cache_cont_start = alloc_res[1]
-            self.out_cache_cont_end = alloc_res[2]
-
         self.req_to_token_pool.req_to_token[
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
@@ -594,7 +583,7 @@ class Batch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
@@ -622,7 +611,7 @@ class Batch:
         self.position_ids_offsets = torch.concat(
             [self.position_ids_offsets, other.position_ids_offsets]
         )
-        self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
@@ -729,8 +718,6 @@ class InputMetadata:
     # Output location of the KV cache
     out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: int = None
-    out_cache_cont_end: int = None

     # Output options
     return_logprob: bool = False
@@ -757,8 +744,6 @@ class InputMetadata:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
         skip_flashinfer_init=False,
@@ -811,8 +796,6 @@ class InputMetadata:
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
             out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
             extend_seq_lens=extend_seq_lens,
             extend_start_loc=extend_start_loc,
             extend_no_prefix=extend_no_prefix,
...
@@ -245,8 +245,6 @@ class ModelRunner:
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
         )
...
@@ -50,6 +50,10 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]

+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 256
+
         self.clear()

     def get_key_buffer(self, layer_id):
@@ -59,14 +63,29 @@ class TokenToKVPool:
         return self.kv_data[layer_id][:, 1]

     def alloc(self, need_size):
-        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-        if select_index.shape[0] < need_size:
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index.to(torch.int32)
+
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
+
+        if select_index.shape[0] < addition_size:
             return None

         self.add_refs(select_index)
-        return select_index.to(torch.int32)
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+        return ret_index.to(torch.int32)

     def alloc_contiguous(self, need_size):
+        # NOTE: This function is deprecated.
         empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
         if empty_index.shape[0] < need_size:
             return None
@@ -89,7 +108,7 @@ class TokenToKVPool:
         return len(torch.nonzero(self.mem_state).squeeze(1))

     def available_size(self):
-        return torch.sum(self.mem_state == 0).item()
+        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)

     def add_refs(self, token_index: torch.Tensor):
         self.total_ref_ct += len(token_index)
...
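For reference, here is a minimal, self-contained sketch of the chunked-prefetch idea from the TokenToKVPool.alloc hunk above. It is written as a plain CPU-only class so it runs without CUDA; the class name, dtypes, and the direct mem_state write (standing in for add_refs and its ref counting) are simplifications for illustration, not the actual sglang code.

import torch

class ChunkedPrefetchPool:
    """Simplified allocator mirroring the diff: free-slot scans are done in
    chunks and cached in a prefetch buffer, amortizing torch.nonzero calls."""

    def __init__(self, size, chunk_size=256):
        self.mem_state = torch.zeros(size, dtype=torch.int16)  # 0 = free
        self.prefetch_buffer = torch.empty(0, dtype=torch.int64)
        self.prefetch_chunk_size = chunk_size

    def alloc(self, need_size):
        buffer_len = len(self.prefetch_buffer)
        if need_size <= buffer_len:
            # Fast path: serve entirely from previously prefetched slots.
            select_index = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return select_index

        # Slow path: scan for at least a full chunk of free slots at once.
        addition_size = need_size - buffer_len
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size]
        if select_index.shape[0] < addition_size:
            return None  # out of memory

        self.mem_state[select_index] = 1  # mark in use (stand-in for add_refs)
        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
        ret_index = self.prefetch_buffer[:need_size]
        self.prefetch_buffer = self.prefetch_buffer[need_size:]
        return ret_index

    def available_size(self):
        # Prefetched-but-unissued slots still count as available.
        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)

pool = ChunkedPrefetchPool(1024)
a = pool.alloc(3)   # triggers one nonzero() scan, prefetches 256 slots
b = pool.alloc(3)   # served from the buffer, no scan
assert pool.available_size() == 1024 - 6

Note the trade-off this sketch makes visible: prefetched slots are marked in use as soon as they are scanned, so available_size must add back the buffered-but-unissued slots, which is exactly what the last hunk above changes in the real pool.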