Commit b14282fb authored by Fanyi Xiao, committed by Facebook GitHub Bot

fix distributed initialization for FSDP

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/657

Without properly setting `requires_grad` for params and buffers, FSDP training hangs. This becomes an issue e.g. when training with LoRA, where most parameters are frozen.
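
A minimal sketch of the failure mode (illustrative only, using CPU tensors instead of the `cuda_device` in the patch): `torch.nn.Parameter` defaults to `requires_grad=True`, so rebuilding a frozen parameter without forwarding the flag silently makes it trainable again, and the materialized module no longer matches the original.

```python
import torch

# A frozen parameter, as LoRA leaves base weights.
frozen = torch.nn.Parameter(torch.zeros(4, 4), requires_grad=False)

# Old behaviour: nn.Parameter defaults to requires_grad=True,
# so the rebuilt parameter becomes trainable.
old = torch.nn.Parameter(torch.empty_like(frozen, device="cpu"))
assert old.requires_grad  # flag was lost

# Fixed behaviour: carry the flag over explicitly.
new = torch.nn.Parameter(
    torch.empty_like(frozen, device="cpu"),
    requires_grad=frozen.requires_grad,
)
assert not new.requires_grad  # still frozen
```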

Reviewed By: wat3rBro

Differential Revision: D55220828

fbshipit-source-id: 1e33aa540c84c4de62a3a37c48a322aa26c98292
parent abdad994
@@ -160,12 +160,19 @@ def bottom_up_nested_fsdp(root_module, fsdp_kwargs: Dict[str, Any]):
                     module,
                     name,
                     torch.nn.Parameter(
-                        torch.empty_like(param, device=cuda_device)
+                        torch.empty_like(param, device=cuda_device),
+                        requires_grad=param.requires_grad,
                     ),
                 )
             for name, buffer in module.named_buffers(recurse=False):
                 setattr(
-                    module, name, torch.empty_like(buffer, device=cuda_device)
+                    module,
+                    name,
+                    torch.empty_like(
+                        buffer,
+                        device=cuda_device,
+                        requires_grad=buffer.requires_grad,
+                    ),
                 )
         else:
             for _, param in module.named_parameters(recurse=False):
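
As a self-contained, CPU-only sketch of the patched loop (the names `TinyLoRABlock` and `materialize_on_device` are hypothetical, not part of d2go), re-materializing a LoRA-style module this way keeps its frozen/trainable split intact, which is what the fix preserves:

```python
import torch
from torch import nn


class TinyLoRABlock(nn.Module):
    """Illustrative module: a frozen base weight plus a trainable LoRA delta."""

    def __init__(self) -> None:
        super().__init__()
        self.base = nn.Parameter(torch.zeros(8, 8), requires_grad=False)  # frozen
        self.lora = nn.Parameter(torch.zeros(8, 8))  # trainable
        self.register_buffer("scale", torch.ones(1))


def materialize_on_device(module: nn.Module, device: torch.device) -> None:
    """Rebuild this module's own params/buffers as empty tensors on `device`,
    carrying over requires_grad as the patched code does."""
    for name, param in module.named_parameters(recurse=False):
        setattr(
            module,
            name,
            nn.Parameter(
                torch.empty_like(param, device=device),
                requires_grad=param.requires_grad,
            ),
        )
    for name, buffer in module.named_buffers(recurse=False):
        setattr(
            module,
            name,
            torch.empty_like(buffer, device=device, requires_grad=buffer.requires_grad),
        )


block = TinyLoRABlock()
trainable_before = {n for n, p in block.named_parameters() if p.requires_grad}
materialize_on_device(block, torch.device("cpu"))
trainable_after = {n for n, p in block.named_parameters() if p.requires_grad}
assert trainable_before == trainable_after == {"lora"}
```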