Unverified commit 4a6e7a66, authored by kk, committed by GitHub

Fix NaN values generated after custom all reduce (#8532)

parent 4b04998d
@@ -184,7 +184,7 @@ class CustomAllreduce:
             # 8*world_size bytes where world_size is at most 8. Allocating 8MB
             # is enough for 131072 such tuples. The largest model I've seen only
             # needs less than 10000 of registered tuples.
-            self.rank_data = torch.empty(
+            self.rank_data = torch.zeros(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -194,14 +194,14 @@ class CustomAllreduce:
         else:
             # meta data buffers need to be "uncached" for signal on MI200
             self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
+            self.buffer = torch.zeros(max_size, dtype=torch.uint8, device=self.device)
             handle = ops.get_meta_buffer_ipc_handle(self.meta)
             shard_data = (
                 bytes(handle),  # ipc handle to base ptr
                 0,  # offset of base ptr
             )
             handles, offsets = self._gather_ipc_meta(shard_data)
-            self.rank_data = torch.empty(
+            self.rank_data = torch.zeros(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -350,14 +350,14 @@ class CustomAllreduce:
     # or, in the context of cuda graphs, register_graph_buffers
     def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         ops.all_reduce_reg(self._ptr, inp, out)
         return out

     # all reduce, assuming inp tensor is NOT IPC registered
     def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out
@@ -375,7 +375,7 @@ class CustomAllreduce:
         buffer.
         """
         if out is None:
-            out = torch.empty_like(inp)
+            out = torch.zeros_like(inp)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
@@ -398,7 +398,7 @@ class CustomAllreduce:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
-                return torch.empty_like(input)
+                return torch.zeros_like(input)
         else:
             if _is_hip:
                 # note: outside of cuda graph context,
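Why this fixes the NaNs: `torch.empty` only reserves memory and leaves its contents uninitialized, so any element a code path does not overwrite (for example the warmup branch above, which returns the output buffer without running the reduction) holds stale bytes that can decode as NaN. `torch.zeros` initializes the allocation to a defined value. Below is a minimal sketch of the difference, separate from the commit itself; the buffer names are illustrative and it assumes a CUDA device is available.

```python
import torch

# torch.empty reserves device memory without initializing it, so the
# tensor holds whatever bytes were left there by earlier allocations.
# Interpreted as float16, some stale bit patterns decode as NaN or Inf.
stale = torch.empty(1 << 20, dtype=torch.float16, device="cuda")
print(torch.isnan(stale).any())  # may print True: contents are undefined

# torch.zeros writes the whole allocation, so a path that returns the
# buffer without fully overwriting it (like the warmup branch in the
# diff) can never leak NaNs from stale memory.
safe = torch.zeros(1 << 20, dtype=torch.float16, device="cuda")
print(torch.isnan(safe).any())  # always False
```

The trade-off is a small one-time memset per allocation, which is negligible here since these buffers are allocated once at initialization or during warmup.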