Unverified Commit 8bca4f87 authored by liangluofb, committed by GitHub

FSDP use _allgather_base and _reduce_scatter_base (#729)



* Update fully_sharded_data_parallel.py

update fully_sharded_data_parallel to use _allgather_base

* Update reduce_scatter_bucketer.py

Use reduce_scatter_base

* Update fully_sharded_data_parallel.py

non-blocking gradient CPU copy, and non-blocking param rebuilds

* Update reduce_scatter_bucketer.py

lints

* Update fully_sharded_data_parallel.py

* Update reduce_scatter_bucketer.py

* Update reduce_scatter_bucketer.py

* lints

* linter, test fix

* linter

* linter: git add fairscale/utils/reduce_scatter_bucketer.py

* linter: git add tests/nn/data_parallel/test_fsdp_overlap.py

* Update test_fsdp_overlap.py

* Update fairscale/utils/reduce_scatter_bucketer.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>

* Update reduce_scatter_bucketer.py

* isort
Co-authored-by: Ubuntu <ubuntu@ip-172-31-9-185.ec2.internal>
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-77-164.ec2.internal>
parent 782714a8
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -435,7 +435,7 @@ class FullyShardedDataParallel(nn.Module):
     @property
     def params_with_grad(self) -> List[Parameter]:
-        """[p for p in self.parameters() if p.grad is not None] """
+        """[p for p in self.parameters() if p.grad is not None]"""
         return [p for p in self.parameters() if p.grad is not None]
 
     @torch.no_grad()
@@ -1315,7 +1315,7 @@ class FullyShardedDataParallel(nn.Module):
         # Optionally move gradients to CPU, typically used if one is running
         # the optimizer on the CPU.
         if self.move_grads_to_cpu:
-            param._cpu_grad.copy_(param.grad.data, non_blocking=False)
+            param._cpu_grad.copy_(param.grad.data, non_blocking=True)
             # Don't let this memory get reused until after the transfer.
             param.grad.data.record_stream(torch.cuda.current_stream())
             param.grad.data = param._cpu_grad
@@ -1448,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module):
            else:
                # If self.move_params_to_cpu and force_full_precision, we need to cast
                # the FP32 CPU param to CUDA for the all-gather.
-                p_data = p.data.to(p._full_param_padded.device)
+                p_data = p.data.to(p._full_param_padded.device, non_blocking=True)
 
            p_size = p._full_param_padded.size()
            assert p_size.numel() % self.world_size == 0
@@ -1463,8 +1463,12 @@ class FullyShardedDataParallel(nn.Module):
                output_tensor = p._full_param_padded
 
            # Fill output_tensor with (p.data for each shard in self.world_size)
-            chunks = list(output_tensor.chunk(self.world_size))
-            dist.all_gather(chunks, p_data, group=self.process_group)
+            if hasattr(dist, "_all_gather_base"):
+                # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
+                dist._all_gather_base(output_tensor, p_data, group=self.process_group)  # type: ignore
+            else:
+                chunks = list(output_tensor.chunk(self.world_size))
+                dist.all_gather(chunks, p_data, group=self.process_group)
 
            # Set p.data = output_tensor (with padding trimmed)
            update_p_data(output_tensor)
...
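For reference, the availability check used above can be isolated into a small helper. This is only an illustrative sketch, not code from the commit: the helper name flat_all_gather is assumed, and it presumes a default process group is already initialized.

import torch
import torch.distributed as dist


def flat_all_gather(output_tensor: torch.Tensor, shard: torch.Tensor, group=None) -> None:
    """Gather `shard` from every rank into the flat `output_tensor`.

    `output_tensor.numel()` must equal `world_size * shard.numel()`.
    """
    if hasattr(dist, "_all_gather_base"):
        # Newer PyTorch: one flat-tensor collective, no per-rank chunk list needed.
        dist._all_gather_base(output_tensor, shard, group=group)
    else:
        # Older PyTorch: chunk the output buffer and use the public all_gather.
        chunks = list(output_tensor.chunk(dist.get_world_size(group)))
        dist.all_gather(chunks, shard, group=group)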
fairscale/utils/reduce_scatter_bucketer.py
@@ -26,9 +26,14 @@ class Bucket:
            assert len(self.callbacks) == 0
            return
        # reduce-scatter bucket
-        dist.reduce_scatter(
-            self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
-        )
+        if hasattr(dist, "_reduce_scatter_base"):
+            dist._reduce_scatter_base(  # type: ignore
+                self.output_shard[: self.offset], self.data[:, : self.offset].contiguous(), group=self.group
+            )
+        else:
+            dist.reduce_scatter(
+                self.output_shard[: self.offset], list(self.data[:, : self.offset].unbind(0)), group=self.group
+            )
 
        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
@@ -39,12 +44,12 @@ class Bucket:
        self.output_shard = torch.zeros_like(self.data[0])
 
    def setup(self) -> None:
-        """ Setup the buffers if they are not allocated.
+        """Setup the buffers if they are not allocated.
 
        Using ``setup`` and ``teardown``, we can ensure that the bucket
        buffers are only allocated during the backward pass, hence saving more
        memory to other parts of the training process, such as the forward pass
        for activation memory.
        """
        for tensor in [self.data, self.output_shard]:
            if tensor.storage().size() == 0:
@@ -122,9 +127,15 @@ class ReduceScatterBucketer:
        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
+            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
-            dist.reduce_scatter(output, input_list, group=group)
+            if hasattr(dist, "_reduce_scatter_base"):
+                input_flattened = torch.cat(input_list)
+                dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
+            else:
+                # fallback
+                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return
...
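The bucketer change follows the same guarded pattern, except that _reduce_scatter_base takes a single flat input tensor, hence the torch.cat flattening (and the TODO above about its cost for CPU tensors). Below is a minimal sketch of that fallback logic; the helper name flat_reduce_scatter is assumed, not part of the commit.

import torch
import torch.distributed as dist


def flat_reduce_scatter(output: torch.Tensor, input_list, group=None) -> None:
    # `output` receives this rank's reduced shard; input_list holds one tensor per rank.
    if hasattr(dist, "_reduce_scatter_base"):
        # The _base variant wants one flat tensor, so the per-rank inputs are concatenated.
        dist._reduce_scatter_base(output, torch.cat(input_list), group=group)
    else:
        # Older PyTorch: the public API accepts the list of per-rank inputs directly.
        dist.reduce_scatter(output, input_list, group=group)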
@@ -150,28 +150,47 @@ class TestNoSyncCommunication(DistributedTest):
        expected_reduce_scatter = num_fsdp
 
        batch = model.module.get_input(torch.device("cuda"))
 
+        # depending on pytorch version the _base methods may not be available
+        method_string_reduce_scatter_base = "torch.distributed._reduce_scatter_base"
+        if hasattr(torch.distributed, "_reduce_scatter_base") is False:
+            # no such method, to make mock_reduce_scatter_base 0 invocation, use an impossible name
+            method_string_reduce_scatter_base = "math.nan"  # just an arbitrary function not going to be called
+        method_string_all_gather_base = "torch.distributed._all_gather_base"
+        if hasattr(torch.distributed, "_all_gather_base") is False:
+            # no such method, to make mock_all_gather_base 0 invocation, use an impossible name
+            method_string_all_gather_base = "math.nan"  # just an arbitrary function not going to be called
+
        with patch("torch.distributed.all_gather") as mock_all_gather:
            with patch("torch.distributed.reduce_scatter") as mock_reduce_scatter:
-                with model.no_sync():
-                    output = model(*batch)
-                    loss = model.module.get_loss(batch, output)
-                    loss.backward()
-
-                assert (
-                    mock_all_gather.call_count == expected_all_gather1
-                ), f"{mock_all_gather.call_count} != {expected_all_gather1}"
-                assert mock_reduce_scatter.call_count == 0, f"{mock_reduce_scatter.call_count} != 0"
-
-                output = model(*batch)
-                loss = model.module.get_loss(batch, output)
-                loss.backward()
-
-                assert (
-                    mock_all_gather.call_count == expected_all_gather2
-                ), f"{mock_all_gather.call_count} != {expected_all_gather2}"
-                assert (
-                    mock_reduce_scatter.call_count == expected_reduce_scatter
-                ), f"{mock_reduce_scatter.call_count} != {expected_reduce_scatter}"
+                with patch(method_string_all_gather_base) as mock_all_gather_base:
+                    with patch(method_string_reduce_scatter_base) as mock_reduce_scatter_base:
+                        with model.no_sync():
+                            output = model(*batch)
+                            loss = model.module.get_loss(batch, output)
+                            loss.backward()
+
+                        # the _base methods are activated when they are available.
+                        # the sum of the _base and public methods should stay the same.
+                        assert (
+                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather1
+                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather1}"
+                        assert (
+                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count == 0
+                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != 0"
+
+                        output = model(*batch)
+                        loss = model.module.get_loss(batch, output)
+                        loss.backward()
+
+                        assert (
+                            mock_all_gather.call_count + mock_all_gather_base.call_count == expected_all_gather2
+                        ), f"{mock_all_gather.call_count + mock_all_gather_base.call_count} != {expected_all_gather2}"
+                        assert (
+                            mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count
+                            == expected_reduce_scatter
+                        ), f"{mock_reduce_scatter.call_count + mock_reduce_scatter_base.call_count} != {expected_reduce_scatter}"
 
 
 if __name__ == "__main__":
...
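The test keeps its nested with patch(...) structure uniform by swapping in a harmless patch target when the private collective does not exist on the installed torch version; the mock is then never invoked and its call count stays at zero. A stripped-down sketch of that trick (the surrounding FSDP test body is elided):

from unittest.mock import patch

import torch.distributed

# If _all_gather_base is missing, patch an arbitrary attribute instead; the mock
# still exists, but it will record zero calls.
target = (
    "torch.distributed._all_gather_base"
    if hasattr(torch.distributed, "_all_gather_base")
    else "math.nan"
)

with patch(target) as mock_all_gather_base:
    # ... run the collective-heavy code under test here ...
    total_base_gathers = mock_all_gather_base.call_count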
tests/nn/data_parallel/test_fsdp_overlap.py
@@ -93,6 +93,9 @@ def _distributed_worker(
    # Save the original torch.distributed.all_gather function since we will
    # patch it to include an artificial delay.
    orig_all_gather = torch.distributed.all_gather
+    orig_all_gather_base = (
+        torch.distributed._all_gather_base if hasattr(torch.distributed, "_all_gather_base") else None
+    )
 
    def run(compute_cycles, all_gather_cycles):
        has_params = all_gather_cycles > 0
@@ -117,6 +120,7 @@ def _distributed_worker(
            cpu_start = time.process_time()
 
            all_gather_called = False
+            all_gather_base_called = False
 
            def _delayed_all_gather(*args, **kwargs):
                nonlocal all_gather_called
@@ -124,17 +128,30 @@ def _distributed_worker(
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)
 
+            def _delayed_all_gather_base(*args, **kwargs):
+                nonlocal all_gather_base_called
+                all_gather_base_called = True
+                torch.cuda._sleep(all_gather_cycles)
+                assert orig_all_gather_base
+                return orig_all_gather_base(*args, **kwargs)
+
+            method_string_all_gather_base = "torch.distributed._all_gather_base"
+            if hasattr(torch.distributed, "_all_gather_base") is False:
+                # no such method, to make mock_all_gather_base 0 invocation, use an impossible name
+                method_string_all_gather_base = "math.nan"
+
            # forward pass
            #
            # Even though both e1 & e2 are on the compute stream, since
            # compute depends on all_gather, e2-e1 includes all_gather time.
            e1.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
-                out = model(batch)
-                if has_params and world_size > 1:
-                    assert all_gather_called
-                else:
-                    assert not all_gather_called
+                with patch(method_string_all_gather_base, _delayed_all_gather_base):
+                    out = model(batch)
+                    if has_params and world_size > 1:
+                        assert all_gather_called or all_gather_base_called
+                    else:
+                        assert not all_gather_called and not all_gather_base_called
            e2.record()
 
            # backward pass
...
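The overlap test exposes compute/communication overlap by delaying the collective before letting it run. The same wrapping idea, reduced to its core, is sketched below; the cycle count and the placeholder model call are assumptions, and torch.cuda._sleep is the private busy-wait helper the test already relies on.

from unittest.mock import patch

import torch
import torch.distributed as dist


def make_delayed(orig_fn, delay_cycles):
    # Spin the GPU on the current stream before forwarding to the real collective,
    # so CUDA event timings around the call show how much of it overlaps with compute.
    def delayed(*args, **kwargs):
        torch.cuda._sleep(delay_cycles)
        return orig_fn(*args, **kwargs)

    return delayed


with patch("torch.distributed.all_gather", make_delayed(dist.all_gather, 100_000)):
    pass  # forward pass of the wrapped FSDP model would run here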