"vllm/vscode:/vscode.git/clone" did not exist on "dcee9be95a0f7fce32ab82060733ab31f90b9154"
Commit a8c92908 authored by xiabo's avatar xiabo
Browse files

add custom allreduce cudagraph

parent 2d6bccd9
...@@ -56,7 +56,7 @@ class CustomAllreduce: ...@@ -56,7 +56,7 @@ class CustomAllreduce:
def __init__(self, def __init__(self,
group: ProcessGroup, group: ProcessGroup,
device: Union[int, str, torch.device], device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None: max_size=8192 * 1024 * 2) -> None:
""" """
Args: Args:
group: the process group to work on. If None, it will use the group: the process group to work on. If None, it will use the
...@@ -90,6 +90,7 @@ class CustomAllreduce: ...@@ -90,6 +90,7 @@ class CustomAllreduce:
return return
rank = dist.get_rank(group=self.group) rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group) world_size = dist.get_world_size(group=self.group)
if world_size == 1: if world_size == 1:
# No need to initialize custom allreduce for single GPU case. # No need to initialize custom allreduce for single GPU case.
...@@ -291,7 +292,7 @@ class CustomAllreduce: ...@@ -291,7 +292,7 @@ class CustomAllreduce:
return None return None
if self._IS_CAPTURING: if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing(): if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=True) return self.all_reduce(input, registered=False)
else: else:
# If warm up, mimic the allocation pattern since custom # If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place. # allreduce is out-of-place.
...@@ -306,8 +307,8 @@ class CustomAllreduce: ...@@ -306,8 +307,8 @@ class CustomAllreduce:
if not self.disabled and self._ptr: if not self.disabled and self._ptr:
ops.dispose(self._ptr) ops.dispose(self._ptr)
self._ptr = 0 self._ptr = 0
self.free_shared_buffer(self.meta_ptrs) self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs) self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self): def __del__(self):
self.close() self.close()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment