Commit d59f30d4 authored by zhuwenwen
Browse files

Use custom_all_reduce when bs * hidden_size <= 256 * 4096 (bw)

parent fee048ff
...@@ -54,7 +54,7 @@ class CustomAllreduce: ...@@ -54,7 +54,7 @@ class CustomAllreduce:
def __init__(self, def __init__(self,
group: ProcessGroup, group: ProcessGroup,
device: Union[int, str, torch.device], device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None: max_size=8192 * 512) -> None:
""" """
Args: Args:
group: the process group to work on. If None, it will use the group: the process group to work on. If None, it will use the
...@@ -230,7 +230,7 @@ class CustomAllreduce: ...@@ -230,7 +230,7 @@ class CustomAllreduce:
return False return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides # for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL. # little performance improvement over NCCL.
return inp_size < self.max_size return inp_size <= self.max_size
def all_reduce(self, def all_reduce(self,
inp: torch.Tensor, inp: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment