"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "8c2c08df56200ef6c42d93925efc494b009bf9b7"
Unverified commit c06efdf6, authored by Myle Ott, committed by GitHub
Browse files

[fix] FSDP: fix CPU offload corner case (#496)

parent ad611a34
...@@ -967,14 +967,17 @@ class FullyShardedDataParallel(nn.Module): ...@@ -967,14 +967,17 @@ class FullyShardedDataParallel(nn.Module):
# before the move_grads_to_cpu step so that this entire hook remains # before the move_grads_to_cpu step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case. # non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision: if self.mixed_precision:
orig_param_grad_data = param.grad.data
param.grad.data = param.grad.data.to(dtype=param.data.dtype) param.grad.data = param.grad.data.to(dtype=param.data.dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())
# Optionally move gradients to CPU, typically used if one is running # Optionally move gradients to CPU, typically used if one is running
# the optimizer on the CPU. # the optimizer on the CPU.
if self.move_grads_to_cpu: if self.move_grads_to_cpu:
param._cpu_grad.copy_(param.grad.data, non_blocking=True) param._cpu_grad.copy_(param.grad.data, non_blocking=False)
# Don't let this memory get reused until after the transfer.
param.grad.data.record_stream(torch.cuda.current_stream())
param.grad.data = param._cpu_grad param.grad.data = param._cpu_grad
# Don't let this memory get reused until after the transfers.
reduced_grad.record_stream(torch.cuda.current_stream())
def _queue_wait_for_post_backward(self) -> None: def _queue_wait_for_post_backward(self) -> None:
"""Try to queue a `wait_for_post_backward` callback. """Try to queue a `wait_for_post_backward` callback.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment