Unverified commit 08f0bd5e, authored by Atream and committed by GitHub

Merge pull request #1168 from kvcache-ai/Atream-patch-1

remove hard code max_length
parents 22a30d70 e6fb4d5a
@@ -70,13 +70,11 @@ class QueryInfo:
 class QueryManager:
     max_length: int = 65536
     page_size: int = 256
     device: torch.device
     query_map : dict[int, QueryInfo]
-    def __init__(self, max_length = 65536, page_size = 256, device = torch.device('cuda')):
-        self.max_length = max_length
+    def __init__(self, page_size = 256, device = torch.device('cuda')):
         self.page_size = page_size
         self.device = device
         self.query_map = {}
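
The constructor no longer accepts a global cap. A minimal before/after sketch of a call site, assuming QueryManager is imported from this module (the setup lines are illustrative, not taken from the repository):

    import torch

    # Before this patch: a hard-coded cap was passed in and stored on the manager.
    # qm = QueryManager(max_length=65536, page_size=256, device=torch.device('cuda'))

    # After this patch: no max_length argument; capacity is decided per query instead.
    qm = QueryManager(page_size=256, device=torch.device('cuda'))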
@@ -87,7 +85,6 @@ class QueryManager:
             id = batch.query_ids[i]
             if id not in self.query_map:
                 print(f"add query id: {id}, batch.query_lengths: {batch.query_lengths[i]}, batch_query_tokens: {batch.query_tokens[i].shape}, batch.block_indexes: {batch.block_indexes[i]}")
-                assert batch.query_tokens[i].size(0) < self.max_length, "query max length in batchquerytodo exceeds internal max_length"
                 query_info = QueryInfo(id=id, query_length=batch.query_lengths[i], max_length=batch.query_tokens[i].size(0) + 1, page_size=self.page_size, device=self.device, temperature=batch.sample_options[i].temperature, top_p=batch.sample_options[i].top_p)
                 query_info.query_tokens[:query_info.query_length].copy_(batch.query_tokens[i][:query_info.query_length].to(self.device))
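
With the assertion removed, prompts are no longer rejected against a fixed 65536-token ceiling; each QueryInfo is instead given a capacity derived from the query itself, max_length = batch.query_tokens[i].size(0) + 1. A hedged sketch of that sizing rule, using an illustrative helper name that does not exist in the repository:

    import torch

    def per_query_capacity(query_tokens: torch.Tensor) -> int:
        # Mirrors max_length=batch.query_tokens[i].size(0) + 1 above:
        # the prompt's own length plus one extra slot, rather than a
        # global hard-coded ceiling.
        return query_tokens.size(0) + 1

    # Example: a 300-token prompt yields a 301-slot QueryInfo buffer.
    assert per_query_capacity(torch.zeros(300, dtype=torch.long)) == 301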
@@ -155,4 +152,4 @@ class QueryManager:
             query_update.active_position = query_info.active_position
             query_updates.append(query_update)
-        return query_updates
\ No newline at end of file
+        return query_updates