Unverified Commit 7e4f72dd authored by Teng Ma's avatar Teng Ma Committed by GitHub
Browse files

[PD] Add get_contiguous_buf_infos interface for MLATokenToKVPool (#5204)

parent 4c31ae9f
...@@ -442,6 +442,19 @@ class MLATokenToKVPool(KVCache): ...@@ -442,6 +442,19 @@ class MLATokenToKVPool(KVCache):
self.layer_transfer_counter = None self.layer_transfer_counter = None
# for disagg
def get_contiguous_buf_infos(self):
    """Describe every per-layer KV buffer as (pointer, total size, item size).

    Returns:
        A tuple ``(kv_data_ptrs, kv_data_lens, kv_item_lens)`` of three
        parallel lists. Each list holds one entry per key buffer for
        layers ``0 .. layer_num - 1``, followed by one entry per value
        buffer in the same layer order:

        * ``kv_data_ptrs`` -- ``data_ptr()`` of the buffer.
        * ``kv_data_lens`` -- ``nbytes`` of the whole buffer.
        * ``kv_item_lens`` -- ``nbytes`` of ``buffer[0]``, i.e. one slice
          along the buffer's leading dimension.
    """
    layer_ids = range(self.layer_num)

    # Fetch each buffer exactly once. get_key_buffer() calls
    # layer_transfer_counter.wait_until() whenever a counter is set, so
    # looking the same layer up three times (once per output list) only
    # repeats that wait without changing the result.
    buffers = [self.get_key_buffer(i) for i in layer_ids]
    buffers += [self.get_value_buffer(i) for i in layer_ids]

    # NOTE(review): if get_value_buffer() returns a view into the same
    # storage as get_key_buffer(), the value entries alias the key entries
    # -- confirm against the buffer layout before relying on both halves.
    kv_data_ptrs = [buf.data_ptr() for buf in buffers]
    kv_data_lens = [buf.nbytes for buf in buffers]
    kv_item_lens = [buf[0].nbytes for buf in buffers]
    return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_key_buffer(self, layer_id: int): def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None: if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id) self.layer_transfer_counter.wait_until(layer_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment