Unverified commit d2cb3024 authored by huangtingwei, committed by GitHub

fix bug where gpu0 occupies more memory when hicache is turned on (#5778)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 1940cdec
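
Why this helps (editorial note, not part of the commit): each HiCache transfer loop runs on its own Python thread, and a freshly spawned thread starts with CUDA device 0 as its current device, because device and stream selection are thread-local in PyTorch and are not inherited from the parent thread. The `with torch.cuda.stream(...)` context manager first records the current stream of that default device before switching, which can initialize a CUDA context on GPU 0 and hold memory there; with one such thread per tensor-parallel rank process, that would explain GPU 0 accumulating extra memory. `torch.cuda.set_stream(...)` instead binds the thread directly to the stream and the stream's own device, so GPU 0 is never consulted. A minimal sketch of the two patterns, assuming a machine where this rank's KV cache and transfer stream live on GPU 1 (the names here are illustrative, not from the patch):

```python
import threading

import torch

# Hypothetical setup: this rank's tensors and its transfer stream live
# on GPU 1, not GPU 0.
transfer_stream = torch.cuda.Stream(device=1)


def loop_before_fix(stop: threading.Event):
    # The context manager first records the current stream of this
    # thread's current device, which defaults to GPU 0 in a new thread;
    # that detour can leave a CUDA context (and memory) on GPU 0.
    with torch.cuda.stream(transfer_stream):
        while not stop.is_set():
            stop.wait(timeout=1)  # stand-in for the transfer loop body


def loop_after_fix(stop: threading.Event):
    # set_stream switches both the current stream and the current device
    # to the stream's own device in one call; GPU 0 stays untouched.
    torch.cuda.set_stream(transfer_stream)
    while not stop.is_set():
        stop.wait(timeout=1)


stop = threading.Event()
threading.Thread(target=loop_after_fix, args=(stop,), daemon=True).start()
```

A side effect of dropping the context manager is that each loop body loses one indentation level, which is why the hunks below touch more lines than a one-line swap would suggest.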
```diff
@@ -268,7 +268,7 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
+        torch.cuda.set_stream(self.write_stream)
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
@@ -291,7 +291,7 @@ class HiCacheController:
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
+        torch.cuda.set_stream(self.load_stream)
         while not self.stop_event.is_set():
             try:
                 operation = self.load_queue.get(block=True, timeout=1)
@@ -299,9 +299,7 @@ class HiCacheController:
                 operation.data = self.mem_pool_host.get_flat_data(
                     operation.host_indices
                 )
-                self.mem_pool_device.transfer(
-                    operation.device_indices, operation.data
-                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -315,7 +313,7 @@ class HiCacheController:
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
+        torch.cuda.set_stream(self.load_stream)
         while not self.stop_event.is_set():
             self.load_cache_event.wait(timeout=1)
             if not self.load_cache_event.is_set():
@@ -360,6 +358,7 @@ class HiCacheController:
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
```
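
A detail worth flagging (my gloss, not stated in the diff): the current stream is thread-local state in PyTorch, so every worker thread has to call `torch.cuda.set_stream(...)` for itself. The auxiliary buffer-preparation function above runs on its own thread (it is spawned as `aux_thread` by the buffer loop below), so it needs its own call even though the buffer thread sets the same stream. A small self-contained illustration of that thread-locality, assuming any CUDA-capable machine:

```python
import threading

import torch

s = torch.cuda.Stream()   # a new stream on the current device
torch.cuda.set_stream(s)  # affects only the calling thread
assert torch.cuda.current_stream() == s


def check():
    # A freshly spawned thread is back on the default stream; stream
    # selection is not inherited from the parent thread.
    assert torch.cuda.current_stream() != s


t = threading.Thread(target=check)
t.start()
t.join()
```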
```diff
@@ -370,13 +369,11 @@ class HiCacheController:
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices)
-                    // self.write_buffer.max_buffer_size
-                )
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
                 if factor >= 1:
@@ -484,10 +481,9 @@ class HiCacheController:
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-        with torch.cuda.stream(self.load_stream):
         while not self.stop_event.is_set():
             operation = self.load_buffer.get()
             if operation is None:
```