Unverified Commit d2cb3024 authored by huangtingwei, committed by GitHub
Browse files

Fix a bug where GPU 0 occupies extra memory when HiCache is turned on (#5778)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 1940cdec
......@@ -268,7 +268,7 @@ class HiCacheController:
"""
Directly write through KV caches to host memory without buffering.
"""
with torch.cuda.stream(self.write_stream):
torch.cuda.set_stream(self.write_stream)
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
......@@ -291,7 +291,7 @@ class HiCacheController:
"""
Directly load KV caches from host memory to device memory without buffering.
"""
with torch.cuda.stream(self.load_stream):
torch.cuda.set_stream(self.load_stream)
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True, timeout=1)
......@@ -299,9 +299,7 @@ class HiCacheController:
operation.data = self.mem_pool_host.get_flat_data(
operation.host_indices
)
self.mem_pool_device.transfer(
operation.device_indices, operation.data
)
self.mem_pool_device.transfer(operation.device_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
......@@ -315,7 +313,7 @@ class HiCacheController:
"""
Load KV caches from host memory to device memory layer by layer.
"""
with torch.cuda.stream(self.load_stream):
torch.cuda.set_stream(self.load_stream)
while not self.stop_event.is_set():
self.load_cache_event.wait(timeout=1)
if not self.load_cache_event.is_set():
......@@ -360,6 +358,7 @@ class HiCacheController:
"""
Auxiliary function to prepare the buffer for write operations.
"""
torch.cuda.set_stream(self.write_stream)
def _to_op(op_):
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
......@@ -370,13 +369,11 @@ class HiCacheController:
return op_
buffer = None
with torch.cuda.stream(self.write_stream):
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
factor = (
len(operation.device_indices)
// self.write_buffer.max_buffer_size
len(operation.device_indices) // self.write_buffer.max_buffer_size
)
if factor >= 1:
......@@ -484,10 +481,9 @@ class HiCacheController:
aux_thread.join()
def load_thread_func_buffer(self):
torch.cuda.set_stream(self.load_stream)
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
aux_thread.start()
with torch.cuda.stream(self.load_stream):
while not self.stop_event.is_set():
operation = self.load_buffer.get()
if operation is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment