Unverified Commit 5949b1ca authored by Ying Sheng, committed by GitHub

Fix memory pool index error (#616)

parent 0feca02d
@@ -137,9 +137,6 @@ class RadixAttention(nn.Module):
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
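The simplified store_kv_cache above assumes out_cache_loc is always a valid pool index, which the reserved dummy slot in TokenToKVPool (see the memory pool hunks below) guarantees even for padded tokens. A minimal, self-contained sketch of that write path, using illustrative shapes and names rather than the real sglang classes:

import torch

# Toy per-layer KV buffers; slot 0 is reserved as a dummy target for padded tokens.
size, head_num, head_dim, layer_num = 8, 2, 4, 1
kv_data = [torch.empty((size + 1, 2, head_num, head_dim)) for _ in range(layer_num)]

def store_kv_cache(layer_id, cache_k, cache_v, out_cache_loc):
    # One pool index per token; padded tokens carry index 0 and only touch the dummy slot.
    kv_data[layer_id][out_cache_loc, 0] = cache_k
    kv_data[layer_id][out_cache_loc, 1] = cache_v

out_cache_loc = torch.tensor([3, 5, 0])  # last token is padding -> dummy slot 0
store_kv_cache(0, torch.randn(3, head_num, head_dim),
               torch.randn(3, head_num, head_dim), out_cache_loc)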
@@ -132,7 +132,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

         # Common inputs
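For context, a rough sketch of what the replay-time padding above does, with made-up sizes and names: the raw batch size is rounded up to the nearest captured graph size with bisect, and the padded slots get neutral inputs (zero sequence length, out_cache_loc pointing at the reserved dummy slot) so they can neither read nor clobber real KV entries:

import bisect
import torch

batch_size_list = [1, 2, 4, 8]   # batch sizes the CUDA graphs were captured for
raw_bs = 3
bs = batch_size_list[bisect.bisect_left(batch_size_list, raw_bs)]  # rounds 3 up to 4

seq_lens = torch.empty(bs, dtype=torch.int32)
out_cache_loc = torch.empty(bs, dtype=torch.int64)
if bs != raw_bs:
    seq_lens.zero_()       # padded entries attend over zero tokens
    out_cache_loc.zero_()  # padded entries write their KV to the dummy slot 0
seq_lens[:raw_bs] = torch.tensor([5, 7, 2], dtype=torch.int32)  # real lengths
out_cache_loc[:raw_bs] = torch.tensor([11, 12, 13])             # real pool indices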
@@ -168,4 +169,4 @@ class CudaGraphRunner:
             prefill_top_logprobs=None,
             decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
         )
-        return output
\ No newline at end of file
+        return output
@@ -315,7 +315,7 @@ class ModelTpServer:
     def get_new_fill_batch(self) -> Optional[Batch]:
         running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
-        if running_bs > self.max_running_requests:
+        if running_bs >= self.max_running_requests:
             return

         # Compute matched prefix length
...@@ -393,7 +393,7 @@ class ModelTpServer: ...@@ -393,7 +393,7 @@ class ModelTpServer:
else: else:
break break
if running_bs + len(can_run_list) > self.max_running_requests: if running_bs + len(can_run_list) >= self.max_running_requests:
break break
if len(can_run_list) == 0: if len(can_run_list) == 0:
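The two comparison changes above are the same off-by-one fix: with a strict ">", a server already holding max_running_requests requests would still admit more. A toy illustration with made-up numbers:

max_running_requests = 16
running_bs = 16  # already at the cap

admits_more_old = not (running_bs > max_running_requests)   # True: old check over-admits
admits_more_new = not (running_bs >= max_running_requests)  # False: new check stops at the cap
assert admits_more_old and not admits_more_new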
@@ -46,7 +46,7 @@ class TokenToKVPool:
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
-            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
@@ -127,4 +127,4 @@ class TokenToKVPool:
         self.total_ref_ct = 0

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
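Taken together, the memory pool changes reserve one extra row as a write target for padded tokens: the buffers hold size + 1 entries and index 0 is permanently referenced, so the allocator never hands it to a real request and dummy writes cannot corrupt live KV entries. A minimal sketch of that idea, with illustrative names rather than the actual TokenToKVPool API:

import torch

class ToyKVPool:
    def __init__(self, size, head_num, head_dim, layer_num, dtype=torch.float16):
        # size + 1 rows: row 0 is the dummy slot, rows 1..size hold real tokens.
        self.kv_data = [
            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype)
            for _ in range(layer_num)
        ]
        self.ref_count = torch.zeros(size + 1, dtype=torch.int32)
        self.ref_count[0] = 1  # pin the dummy slot so alloc() never returns it

    def alloc(self, need):
        free = torch.nonzero(self.ref_count == 0).flatten()[:need]
        self.ref_count[free] += 1
        return free

pool = ToyKVPool(size=8, head_num=2, head_dim=4, layer_num=1)
loc = pool.alloc(3)
assert 0 not in loc  # slot 0 stays reserved for dummy writes from padded tokens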