Unverified Commit d0730487 authored by pansicheng, committed by GitHub

fix 3fs zerocopy (#9938)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent b32ab070
@@ -324,6 +324,22 @@ class HiCacheController:
             group_ranks, backend="gloo"
         )
 
+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
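The constructor now resolves the backend-specific get/set/exists callables once, replacing the per-call branching that previously lived in `_page_transfer` and `_page_backup` (removed further down). A minimal standalone sketch of the pattern, using a hypothetical `Controller` class rather than the repo's code:

```python
# Sketch only: resolve hot-path functions once at construction time
# instead of re-branching on backend type inside every transfer call.
class Controller:
    def __init__(self, backend_type: str, layout: str):
        self.page_get_func = self._generic_page_get  # default
        if backend_type == "mooncake":
            self.page_get_func = self._mooncake_page_get
        elif backend_type == "hf3fs" and layout == "page_first":
            self.page_get_func = self._3fs_zero_copy_page_get

    def _generic_page_get(self):
        return "generic"

    def _mooncake_page_get(self):
        return "mooncake"

    def _3fs_zero_copy_page_get(self):
        return "3fs-zerocopy"


assert Controller("hf3fs", "page_first").page_get_func() == "3fs-zerocopy"
assert Controller("hf3fs", "layer_first").page_get_func() == "generic"
```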
@@ -617,13 +633,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)
 
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -670,17 +692,6 @@ class HiCacheController:
                 break  # Operation terminated by controller
 
     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.storage_backend_type == "mooncake":
-            get_func = self._mooncake_page_get
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-        else:
-            get_func = self._generic_page_get
-
         # Transfer batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -689,7 +700,7 @@ class HiCacheController:
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
@@ -746,7 +757,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
            hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -839,23 +850,13 @@ class HiCacheController:
     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)
 
     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.storage_backend_type == "mooncake":
-            backup_set_func = self._mooncake_page_set
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-        else:
-            backup_set_func = self._generic_page_set
 
         # Backup batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
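In the write path, `_3fs_zero_copy_page_set` discards the new `factor` (`hashes, dsts, _ = ...`): `batch_set` returns a single success flag for the whole expanded batch, so there is no per-key count to renormalize. A minimal sketch under that assumption, with hypothetical names:

```python
# Sketch only: the zero-copy write path ignores `factor` because
# batch_set reports one success flag for the whole batch.
def zero_copy_page_set(get_buffer_with_hash, batch_set, hash_values, host_indices):
    keys, bufs, _ = get_buffer_with_hash(hash_values, host_indices)
    return batch_set(keys, bufs)  # bool: all-or-nothing (todo: partial success)

ok = zero_copy_page_set(
    lambda h, i: ([f"{x}-k" for x in h] + [f"{x}-v" for x in h], [], 2),
    lambda keys, bufs: True,  # stand-in backend that always succeeds
    ["h0"], [0],
)
assert ok is True
```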
@@ -864,7 +865,7 @@ class HiCacheController:
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."
...
@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))
 
         key_list = []
         buf_list = []
 
-        for key, i in zip(keys, range(0, len(indices), self.page_size)):
+        for i in range(len(keys)):
+            key = keys[i]
             key_list.append(f"{key}-k")
-            buf_list.append(self.k_buffer[i : i + self.page_size])
             key_list.append(f"{key}-v")
-            buf_list.append(self.v_buffer[i : i + self.page_size])
+            if indices is not None:
+                index = indices[i * self.page_size]
+                buf_list.append(self.k_buffer[index : index + self.page_size])
+                buf_list.append(self.v_buffer[index : index + self.page_size])
 
-        return key_list, buf_list
+        return key_list, buf_list, 2
 
 
 class MLATokenToKVPoolHost(HostKVCache):
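Two things change here. `indices` becomes optional, so the controller can fetch just the expanded keys and the factor (as `_3fs_zero_copy_batch_exists` does) without materializing buffers. And the buffers are now sliced at `indices[i * self.page_size]` rather than at the loop offset `i`, so the slices follow the actual host page slots; this appears to be the core of the zero-copy fix. A worked example of the new behavior, with plain lists standing in for the host tensors (names and sizes are illustrative):

```python
# Worked example: page_size, buffers and hashes are made up for illustration.
page_size = 2
k_buffer = ["k0", "k1", "k2", "k3", "k4", "k5"]
v_buffer = ["v0", "v1", "v2", "v3", "v4", "v5"]

def get_buffer_with_hash(keys, indices=None):
    key_list, buf_list = [], []
    for i in range(len(keys)):
        key = keys[i]
        key_list.append(f"{key}-k")
        key_list.append(f"{key}-v")
        if indices is not None:
            index = indices[i * page_size]  # first token slot of page i
            buf_list.append(k_buffer[index : index + page_size])
            buf_list.append(v_buffer[index : index + page_size])
    return key_list, buf_list, 2  # factor 2: two storage keys per page

# Two pages whose host slots start at 4 and 0 (non-contiguous on purpose).
keys, bufs, factor = get_buffer_with_hash(["a", "b"], indices=[4, 5, 0, 1])
assert keys == ["a-k", "a-v", "b-k", "b-v"] and factor == 2
assert bufs == [["k4", "k5"], ["v4", "v5"], ["k0", "k1"], ["v0", "v1"]]
# Key-only mode, as used by _3fs_zero_copy_batch_exists:
assert get_buffer_with_hash(["a", "b"]) == (keys, [], 2)
```

The MLA variant below is analogous, with a single kv_buffer and factor 1.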
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
         element_size_list = [element_size] * len(key_list)
         return key_list, ptr_list, element_size_list
 
-    def get_buffer_with_hash(self, keys, indices):
+    def get_buffer_with_hash(self, keys, indices=None):
         assert self.layout == "page_first"
-        assert len(keys) == (len(indices) // self.page_size)
+        assert indices is None or (len(keys) == (len(indices) // self.page_size))
 
         buf_list = []
 
-        for i in range(0, len(indices), self.page_size):
-            buf_list.append(self.kv_buffer[i : i + self.page_size])
+        if indices is not None:
+            for i in range(len(keys)):
+                index = indices[i * self.page_size]
+                buf_list.append(self.kv_buffer[index : index + self.page_size])
 
-        return keys, buf_list
+        return keys, buf_list, 1
...
@@ -415,22 +415,12 @@ class HiCacheHF3FS(HiCacheStorage):
         return result[0] if result else False
 
     def batch_exists(self, keys: List[str]) -> int:
-        if self.is_page_first_layout and not self.is_mla_model:
-            query_keys = []
-            # Compatible with page_first layout's key format, Refer to memory_pool_host.py#get_buffer_with_hash
-            for key in keys:
-                query_keys.append(f"{key}-k")
-                query_keys.append(f"{key}-v")
-            key_multiplier = 2
-        else:
-            query_keys = keys
-            key_multiplier = 1
-
-        exist_result = self.metadata_client.exists(self.rank, query_keys)
-        for i in range(len(query_keys)):
-            if not exist_result[i]:
-                return i // key_multiplier
-        return len(query_keys) // key_multiplier
+        results = self.metadata_client.exists(self.rank, keys)
+        for i in range(len(keys)):
+            if not results[i]:
+                return i
+
+        return len(keys)
 
     def clear(self) -> bool:
         try:
...
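With the key expansion and `key_multiplier` bookkeeping moved into `get_buffer_with_hash` (and the division into `_3fs_zero_copy_batch_exists`), the backend's `batch_exists` reduces to a prefix-hit count over whatever keys it is handed, and no longer needs layout or model awareness here. A hedged equivalence check between the old and new division of labor (illustrative code, not from the repo):

```python
# Illustrative only: old in-backend suffixing vs. new controller-side
# normalization should agree on the number of fully present pages.
def old_batch_exists(keys, exists):
    query = []
    for k in keys:                 # old: backend expands keys itself
        query.append(f"{k}-k")
        query.append(f"{k}-v")
    for i, q in enumerate(query):
        if not exists(q):
            return i // 2          # old: backend divides by key_multiplier
    return len(query) // 2

def new_batch_exists(expanded_keys, exists, factor):
    hit = 0
    for q in expanded_keys:        # new: backend sees pre-expanded keys
        if not exists(q):
            break
        hit += 1
    return hit // factor           # new: controller divides by factor

present = {"a-k", "a-v", "b-k"}    # page "b" is missing its V half
exists = lambda k: k in present
expanded = ["a-k", "a-v", "b-k", "b-v"]
assert old_batch_exists(["a", "b"], exists) == new_batch_exists(expanded, exists, 2) == 1
```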