Unverified Commit d2cb3024 authored by huangtingwei, committed by GitHub

Fix a bug where GPU 0 occupies more memory when HiCache is turned on (#5778)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 1940cdec
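
The diff below replaces the `with torch.cuda.stream(...)` context manager that wrapped each HiCache worker loop with a single `torch.cuda.set_stream(...)` call made once at the top of the thread, which the commit title associates with the extra GPU 0 memory seen when HiCache is enabled. Here is a minimal sketch of the new pattern, assuming a CUDA-enabled PyTorch build; the worker body, stream, and event names are illustrative and not the project's actual code:

```python
# Minimal sketch of the pattern this commit adopts (illustrative only; the
# worker body, stream, and event names below are not the project's code).
# Assumes a CUDA-enabled PyTorch build.
import threading

import torch


def copy_worker(stream: torch.cuda.Stream, stop: threading.Event) -> None:
    # Bind this thread's CUDA work to the side stream once, instead of
    # wrapping the whole loop in `with torch.cuda.stream(stream):`.
    torch.cuda.set_stream(stream)
    while not stop.is_set():
        # Stand-in for the real host<->device KV-cache transfers.
        host_buf = torch.ones(1024).pin_memory()
        device_buf = torch.empty(1024, device="cuda")
        device_buf.copy_(host_buf, non_blocking=True)
        stream.synchronize()
        stop.wait(timeout=0.1)


if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    stop_event = threading.Event()
    worker = threading.Thread(
        target=copy_worker, args=(side_stream, stop_event), daemon=True
    )
    worker.start()
    stop_event.set()  # let the sketch exit immediately
    worker.join()
```

Because the current CUDA stream is a thread-local setting in PyTorch, the single `set_stream` call scopes the side stream to the worker thread for its whole lifetime, which is all these daemon loops need since they never restore the previous stream.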
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(
-                        operation.host_indices
-                    )
-                    self.mem_pool_device.transfer(
-                        operation.device_indices, operation.data
-                    )
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
-                if batch_operation is None:
-                    continue
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
+                if batch_operation is None:
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
 
-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
+                if factor >= 1:
+                    if buffer is not None:
+                        _to_op(buffer)
+                        buffer = None
 
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
+                    continue
 
-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
-                        _to_op(buffer)
-                        buffer = None
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
@@ -484,19 +481,18 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
 
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()
 
     def evict_device(