Unverified Commit 5949b1ca authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Fix memory pool index error (#616)

parent 0feca02d
......@@ -137,9 +137,6 @@ class RadixAttention(nn.Module):
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
if input_metadata.out_cache_loc is not None:
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer[input_metadata.out_cache_loc] = cache_v
else:
raise RuntimeError()
value_buffer[input_metadata.out_cache_loc] = cache_v
......@@ -132,7 +132,8 @@ class CudaGraphRunner:
index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.seq_lens.zero_()
self.position_ids_offsets.fill_(1)
self.out_cache_loc.zero_()
# Common inputs
......@@ -168,4 +169,4 @@ class CudaGraphRunner:
prefill_top_logprobs=None,
decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
)
return output
\ No newline at end of file
return output
......@@ -315,7 +315,7 @@ class ModelTpServer:
def get_new_fill_batch(self) -> Optional[Batch]:
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
if running_bs > self.max_running_requests:
if running_bs >= self.max_running_requests:
return
# Compute matched prefix length
......@@ -393,7 +393,7 @@ class ModelTpServer:
else:
break
if running_bs + len(can_run_list) > self.max_running_requests:
if running_bs + len(can_run_list) >= self.max_running_requests:
break
if len(can_run_list) == 0:
......
......@@ -46,7 +46,7 @@ class TokenToKVPool:
# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
......@@ -127,4 +127,4 @@ class TokenToKVPool:
self.total_ref_ct = 0
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
self.add_refs(torch.tensor([0], dtype=torch.int32))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment