Commit b14282fb authored by Fanyi Xiao, committed by Facebook GitHub Bot

fix distributed initialization for FSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/657

Without properly setting `requires_grad` on the re-created params and buffers, FSDP training can hang. This becomes an issue e.g. when training with LoRA, where most base parameters are frozen.
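A minimal sketch of the failure mode (standalone illustration, not the exact d2go call site): when a parameter is re-created on the target device with `torch.nn.Parameter(torch.empty_like(p, device=...))`, the new parameter defaults to `requires_grad=True`, so a frozen weight silently becomes trainable and the gradient setup no longer matches what FSDP expects. Passing `requires_grad` through, as this diff does, preserves the original state.

```python
import torch
import torch.nn as nn

# Illustrative module: a frozen base weight, as in LoRA fine-tuning.
base = nn.Linear(8, 8)
base.weight.requires_grad_(False)

# Re-creating the parameter WITHOUT copying requires_grad:
# nn.Parameter defaults to requires_grad=True, so the frozen flag is lost.
bad = nn.Parameter(torch.empty_like(base.weight, device="cpu"))
print(bad.requires_grad)  # True -- silently un-frozen

# Copying requires_grad from the source parameter keeps it frozen,
# which is what this fix does for both params and buffers.
good = nn.Parameter(
    torch.empty_like(base.weight, device="cpu"),
    requires_grad=base.weight.requires_grad,
)
print(good.requires_grad)  # False -- matches the original parameter
```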

Reviewed By: wat3rBro

Differential Revision: D55220828

fbshipit-source-id: 1e33aa540c84c4de62a3a37c48a322aa26c98292
parent abdad994
@@ -160,12 +160,19 @@ def bottom_up_nested_fsdp(root_module, fsdp_kwargs: Dict[str, Any]):
                     module,
                     name,
                     torch.nn.Parameter(
-                        torch.empty_like(param, device=cuda_device)
+                        torch.empty_like(param, device=cuda_device),
+                        requires_grad=param.requires_grad,
                     ),
                 )
             for name, buffer in module.named_buffers(recurse=False):
                 setattr(
-                    module, name, torch.empty_like(buffer, device=cuda_device)
+                    module,
+                    name,
+                    torch.empty_like(
+                        buffer,
+                        device=cuda_device,
+                        requires_grad=buffer.requires_grad,
+                    ),
                 )
         else:
             for _, param in module.named_parameters(recurse=False):