"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "c6a0cacc35c59a66a59d631dff86e8be46479746"
Unverified Commit 5949b1ca authored by Ying Sheng, committed by GitHub

Fix memory pool index error (#616)
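Allocates one extra slot per layer in TokenToKVPool, drops the out_cache_loc None-check in RadixAttention.store_kv_cache, resets seq_lens/position_ids_offsets/out_cache_loc for padded CUDA graph batches, and tightens the max_running_requests checks in ModelTpServer from > to >=.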

parent 0feca02d
@@ -137,9 +137,6 @@ class RadixAttention(nn.Module):
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
-        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-        if input_metadata.out_cache_loc is not None:
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        else:
-            raise RuntimeError()
+        key_buffer[input_metadata.out_cache_loc] = cache_k
+        value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        value_buffer[input_metadata.out_cache_loc] = cache_v
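The rewritten store_kv_cache always indexes the pooled buffers with out_cache_loc instead of branching on None. A minimal sketch of that indexed scatter-write pattern, using illustrative shapes and tensor names rather than the real sglang objects:

```python
import torch

# Illustrative pool dimensions only (not sglang's real configuration).
pool_size, head_num, head_dim = 8, 2, 4
key_buffer = torch.zeros(pool_size, head_num, head_dim)
value_buffer = torch.zeros(pool_size, head_num, head_dim)

# One pool slot per new token in the batch.
out_cache_loc = torch.tensor([3, 5, 6])
cache_k = torch.randn(3, head_num, head_dim)
cache_v = torch.randn(3, head_num, head_dim)

# Scatter the new keys/values into their assigned slots in one indexed write.
key_buffer[out_cache_loc] = cache_k
value_buffer[out_cache_loc] = cache_v
```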
@@ -132,7 +132,8 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.batch_size_list, raw_bs)
         bs = self.batch_size_list[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.zero_()
+            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()

         # Common inputs
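For CUDA graph replay, the runner rounds the incoming batch size up to the nearest captured size and resets the inputs of the padded (dummy) entries; the zeroed out_cache_loc presumably sends their KV writes to the spare slot 0 added to the pool below. A rough sketch of the padding step, with a made-up batch_size_list and without the surrounding class:

```python
import bisect
import torch

# Hypothetical list of batch sizes captured as CUDA graphs (must be sorted).
batch_size_list = [1, 2, 4, 8, 16]

def pad_batch(raw_bs: int):
    # Round the raw batch size up to the nearest captured size.
    index = bisect.bisect_left(batch_size_list, raw_bs)
    bs = batch_size_list[index]

    seq_lens = torch.empty(bs, dtype=torch.int32)
    position_ids_offsets = torch.empty(bs, dtype=torch.int32)
    out_cache_loc = torch.empty(bs, dtype=torch.int64)

    if bs != raw_bs:
        # Reset everything; the first raw_bs entries are overwritten with the
        # real batch afterwards, while the padded tail keeps these dummy
        # values and writes its KV cache to slot 0.
        seq_lens.zero_()
        position_ids_offsets.fill_(1)
        out_cache_loc.zero_()
    return bs, seq_lens, position_ids_offsets, out_cache_loc

bs, *_ = pad_batch(3)  # a raw batch of 3 is padded up to 4
```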
@@ -315,7 +315,7 @@ class ModelTpServer:
     def get_new_fill_batch(self) -> Optional[Batch]:
         running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
-        if running_bs > self.max_running_requests:
+        if running_bs >= self.max_running_requests:
             return

         # Compute matched prefix length
@@ -393,7 +393,7 @@ class ModelTpServer:
             else:
                 break

-            if running_bs + len(can_run_list) > self.max_running_requests:
+            if running_bs + len(can_run_list) >= self.max_running_requests:
                 break

         if len(can_run_list) == 0:
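Both scheduler hunks change the admission guard from > to >=: previously a new prefill batch could still be admitted when the running batch had already reached max_running_requests, pushing it past the cap by one scheduling round. The boundary case, spelled out with an assumed cap of 8:

```python
max_running_requests = 8
running_bs = 8  # already at the cap

# Old guard: 8 > 8 is False, so more requests would still be admitted.
admit_old = not (running_bs > max_running_requests)   # True

# New guard: 8 >= 8 is True, so admission stops exactly at the cap.
admit_new = not (running_bs >= max_running_requests)  # False
```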
@@ -46,7 +46,7 @@ class TokenToKVPool:
         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
-            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
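Each per-layer buffer now holds size + 1 slots. The allocator is not shown in this diff, but the most plausible reading is that the extra slot keeps index 0 free as a scratch location for the padded CUDA graph entries above (out_cache_loc.zero_()), so their writes neither clobber a real token's cache line nor run past the end of the buffer. A toy allocation, on CPU and with made-up sizes:

```python
import torch

# Made-up pool parameters; device="cuda" omitted so the sketch runs anywhere.
size, head_num, head_dim, layer_num = 16, 2, 4, 3
dtype = torch.float16

# One extra slot per layer: slot 0 absorbs the writes of padded batch
# entries, slots 1..size back real tokens.
kv_data = [
    torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype)
    for _ in range(layer_num)
]

# A padded entry writing keys/values to slot 0 leaves the real slots untouched.
kv_data[0][0] = torch.zeros(2, head_num, head_dim, dtype=dtype)
```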