Unverified commit 8bca4f87, authored by liangluofb, committed by GitHub

FSDP use _allgather_base and _reduce_scatter_base (#729)



* Update fully_sharded_data_parallel.py

update fully_sharded_data_parallel to use _allgather_base

* Update reduce_scatter_bucketer.py

Use reduce_scatter_base

* Update fully_sharded_data_parallel.py

non-blocking gradient CPU copy, and non-blocking param rebuilds

* Update reduce_scatter_bucketer.py

lints

* Update fully_sharded_data_parallel.py

* Update reduce_scatter_bucketer.py

* Update reduce_scatter_bucketer.py

* lints

* linter, test fix

* linter

* linter: fairscale/utils/reduce_scatter_bucketer.py

* linter: tests/nn/data_parallel/test_fsdp_overlap.py

* Update test_fsdp_overlap.py

* Update fairscale/utils/reduce_scatter_bucketer.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Update reduce_scatter_bucketer.py

* isort
Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-185.ec2.internal>
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-77-164.ec2.internal>
parent 782714a8
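For context, here is a minimal, hypothetical sketch (not part of this commit) of the two all-gather call shapes the diff below switches between. It assumes torch.distributed has been initialized and that the private _all_gather_base API exists in the installed PyTorch version; the helper name gather_full_param is invented for illustration.

import torch
import torch.distributed as dist

def gather_full_param(output_tensor: torch.Tensor, shard: torch.Tensor) -> None:
    # shard: this rank's slice; output_tensor: preallocated flat buffer of
    # size world_size * shard.numel() that receives every rank's slice.
    if hasattr(dist, "_all_gather_base"):
        # One flat output tensor, no per-rank list bookkeeping.
        dist._all_gather_base(output_tensor, shard)
    else:
        # Older PyTorch: gather into a list of world_size views of the buffer.
        chunks = list(output_tensor.chunk(dist.get_world_size()))
        dist.all_gather(chunks, shard)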
@@ -435,7 +435,7 @@ class FullyShardedDataParallel(nn.Module):

     @property
     def params_with_grad(self) -> List[Parameter]:
-        """[p for p in self.parameters() if p.grad is not None] """
+        """[p for p in self.parameters() if p.grad is not None]"""
         return [p for p in self.parameters() if p.grad is not None]

     @torch.no_grad()
@@ -1315,7 +1315,7 @@ class FullyShardedDataParallel(nn.Module):

             # Optionally move gradients to CPU, typically used if one is running
             # the optimizer on the CPU.
             if self.move_grads_to_cpu:
-                param._cpu_grad.copy_(param.grad.data, non_blocking=False)
+                param._cpu_grad.copy_(param.grad.data, non_blocking=True)
                 # Don't let this memory get reused until after the transfer.
                 param.grad.data.record_stream(torch.cuda.current_stream())
                 param.grad.data = param._cpu_grad
@@ -1448,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module):

             else:
                 # If self.move_params_to_cpu and force_full_precision, we need to cast
                 # the FP32 CPU param to CUDA for the all-gather.
-                p_data = p.data.to(p._full_param_padded.device)
+                p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

             p_size = p._full_param_padded.size()
             assert p_size.numel() % self.world_size == 0
@@ -1463,6 +1463,10 @@ class FullyShardedDataParallel(nn.Module):

             output_tensor = p._full_param_padded
             # Fill output_tensor with (p.data for each shard in self.world_size)
-            chunks = list(output_tensor.chunk(self.world_size))
-            dist.all_gather(chunks, p_data, group=self.process_group)
+            if hasattr(dist, "_all_gather_base"):
+                # Newer versions of PyTorch have _all_gather_base, which is faster
+                # than chunk() followed by all_gather.
+                dist._all_gather_base(output_tensor, p_data, group=self.process_group)  # type: ignore
+            else:
+                chunks = list(output_tensor.chunk(self.world_size))
+                dist.all_gather(chunks, p_data, group=self.process_group)
...
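A standalone sketch of the non-blocking gradient-to-CPU copy pattern from the hunk at line 1315 above; the tensor names are hypothetical stand-ins for param.grad.data and param._cpu_grad, and the CPU buffer must be pinned for the copy to actually run asynchronously.

import torch

gpu_grad = torch.randn(1024, device="cuda")    # stands in for param.grad.data
cpu_grad = torch.empty(1024, pin_memory=True)  # stands in for param._cpu_grad

# Start the device-to-host copy without blocking the CPU.
cpu_grad.copy_(gpu_grad, non_blocking=True)
# Tie gpu_grad's memory to the current stream so the caching allocator does not
# reuse it for another tensor before the in-flight copy finishes.
gpu_grad.record_stream(torch.cuda.current_stream())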
@@ -26,6 +26,11 @@ class Bucket:

             assert len(self.callbacks) == 0
             return

         # reduce-scatter bucket
-        dist.reduce_scatter(
-            self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
-        )
+        if hasattr(dist, "_reduce_scatter_base"):
+            dist._reduce_scatter_base(  # type: ignore
+                self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
+            )
+        else:
+            dist.reduce_scatter(
+                self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
+            )
@@ -39,7 +44,7 @@ class Bucket:

         self.output_shard = torch.zeros_like(self.data[0])

     def setup(self) -> None:
-        """ Setup the buffers if they are not allocated.
+        """Setup the buffers if they are not allocated.

         Using ``setup`` and ``teardown``, we can ensure that the bucket
         buffers are only allocated during the backward pass, hence saving more
@@ -122,8 +127,14 @@ class ReduceScatterBucketer:

         bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
         if first_input_size > bucket_shard_size:
+            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
             # input is too big to fit in the bucket, reduce-scatter directly
             output = torch.zeros_like(input_list[0])
-            dist.reduce_scatter(output, input_list, group=group)
+            if hasattr(dist, "_reduce_scatter_base"):
+                input_flattened = torch.cat(input_list)
+                dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
+            else:
+                # fallback
+                dist.reduce_scatter(output, input_list, group=group)
             if callback_fn is not None:
                 callback_fn(output)
...
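The same version guard, sketched as a standalone helper (hypothetical name reduce_scatter_any; assumes an initialized process group). torch.cat flattens the per-rank inputs because _reduce_scatter_base takes a single contiguous tensor, which is the cost the TODO above refers to.

import torch
import torch.distributed as dist

def reduce_scatter_any(output: torch.Tensor, input_list: list) -> None:
    # input_list holds world_size equally sized tensors, one destined per rank;
    # output receives this rank's fully reduced shard.
    if hasattr(dist, "_reduce_scatter_base"):
        dist._reduce_scatter_base(output, torch.cat(input_list))
    else:
        dist.reduce_scatter(output, input_list)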
@@ -150,28 +150,47 @@ class TestNoSyncCommunication(DistributedTest):

         expected_reduce_scatter = num_fsdp

         batch = model.module.get_input(torch.device("cuda"))

+        # Depending on the PyTorch version, the _base methods may not be available.
+        method_string_reduce_scatter_base = "torch.distributed._reduce_scatter_base"
+        if not hasattr(torch.distributed, "_reduce_scatter_base"):
+            # No such method: patch an arbitrary name that will never be called,
+            # so mock_reduce_scatter_base records zero invocations.
+            method_string_reduce_scatter_base = "math.nan"
+
+        method_string_all_gather_base = "torch.distributed._all_gather_base"
+        if not hasattr(torch.distributed, "_all_gather_base"):
+            # Same trick for the all-gather variant.
+            method_string_all_gather_base = "math.nan"
+
         with patch("torch.distributed.all_gather") as mock_all_gather:
             with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
+                with patch(method_string_all_gather_base) as mock_all_gather_base:
+                    with patch(method_string_reduce_scatter_base) as mock_reduce_scatter_base:
                         with model.no_sync():
                             output = model(*batch)
                             loss = model.module.get_loss(batch, output)
                             loss.backward()

+                        # The _base methods are used whenever they are available, so the
+                        # sum of the _base and public call counts should stay the same.
                         assert (
-                            mock_all_gather.call_count == expected_all_gather1
-                        ), f"{mock_all_gather.call_count} != {expected_all_gather1}"
-                        assert mock_reduce_scatter.call_count == 0, f"{mock_reduce_scatter.call_count} != 0"
+                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1
+                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather1}"
+                        assert (
+                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0
+                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != 0"

                         output = model(*batch)
                         loss = model.module.get_loss(batch, output)
                         loss.backward()

                         assert (
-                            mock_all_gather.call_count == expected_all_gather2
-                        ), f"{mock_all_gather.call_count} != {expected_all_gather2}"
+                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2
+                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather2}"
                         assert (
-                            mock_reduce_scatter.call_count == expected_reduce_scatter
-                        ), f"{mock_reduce_scatter.call_count} != {expected_reduce_scatter}"
+                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count
+                            == expected_reduce_scatter
+                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"


 if __name__ == "__main__":
...
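The "math.nan" trick above, in isolation: unittest.mock.patch needs an importable target, so when the private collective does not exist the test patches a harmless unrelated attribute, giving it a mock that is guaranteed to record zero calls. A minimal sketch:

import torch.distributed
from unittest.mock import patch

target = (
    "torch.distributed._all_gather_base"
    if hasattr(torch.distributed, "_all_gather_base")
    else "math.nan"  # importable but never invoked, so call_count stays 0
)
with patch(target) as mock_base:
    pass  # run the workload here, then assert on mock_base.call_count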
@@ -93,6 +93,9 @@ def _distributed_worker(

     # Save the original torch.distributed.all_gather function since we will
     # patch it to include an artificial delay.
     orig_all_gather = torch.distributed.all_gather
+    orig_all_gather_base = (
+        torch.distributed._all_gather_base if hasattr(torch.distributed, "_all_gather_base") else None
+    )

     def run(compute_cycles, all_gather_cycles):
         has_params = all_gather_cycles > 0
@@ -117,6 +120,7 @@ def _distributed_worker(

         cpu_start = time.process_time()

         all_gather_called = False
+        all_gather_base_called = False

         def _delayed_all_gather(*args, **kwargs):
             nonlocal all_gather_called
@@ -124,17 +128,30 @@ def _distributed_worker(

             torch.cuda._sleep(all_gather_cycles)
             return orig_all_gather(*args, **kwargs)

+        def _delayed_all_gather_base(*args, **kwargs):
+            nonlocal all_gather_base_called
+            all_gather_base_called = True
+            torch.cuda._sleep(all_gather_cycles)
+            assert orig_all_gather_base
+            return orig_all_gather_base(*args, **kwargs)
+
+        method_string_all_gather_base = "torch.distributed._all_gather_base"
+        if not hasattr(torch.distributed, "_all_gather_base"):
+            # No such method: patch an arbitrary name that will never be called,
+            # so the mock records zero invocations.
+            method_string_all_gather_base = "math.nan"

         # forward pass
         #
         # Even though both e1 & e2 are on the compute stream, since
         # compute depends on all_gather, e2-e1 includes all_gather time.
         e1.record()
         with patch("torch.distributed.all_gather", _delayed_all_gather):
-            out = model(batch)
-            if has_params and world_size > 1:
-                assert all_gather_called
-            else:
-                assert not all_gather_called
+            with patch(method_string_all_gather_base, _delayed_all_gather_base):
+                out = model(batch)
+                if has_params and world_size > 1:
+                    assert all_gather_called or all_gather_base_called
+                else:
+                    assert not all_gather_called and not all_gather_base_called
         e2.record()

         # backward pass
...
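The e1/e2 measurement relies on CUDA events recorded on the compute stream: because the forward compute has to wait on the (artificially delayed) all-gather, the elapsed time between the two events includes the collective. A minimal sketch of that timing pattern:

import torch

e1 = torch.cuda.Event(enable_timing=True)
e2 = torch.cuda.Event(enable_timing=True)
e1.record()
# ... run the forward pass on the current stream ...
e2.record()
torch.cuda.synchronize()  # make sure both events have completed
ms = e1.elapsed_time(e2)  # GPU time between the events, in milliseconds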