Commit feb74214 authored by Chien-Chin Huang's avatar Chien-Chin Huang Committed by Facebook GitHub Bot
Browse files

Add the missing optimizer argument

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/528

Calling shard_full_optim_state_dict() without the optimizer argument is deprecated. This diff passes the optimizer object explicitly to shard_full_optim_state_dict().

Reviewed By: YanjunChen329

Differential Revision: D45065185

fbshipit-source-id: 0abec3eeff6e7c626eefc432c73e38779a6f02d9
parent 8353ad23
......@@ -14,9 +14,11 @@ def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
"""
# FSDP: full_optim_state_dict() needs to be called by all ranks
if model.state_dict_type == StateDictType.FULL_STATE_DICT:
return FSDP.full_optim_state_dict(model, optimizer, rank0_only=model.rank0_only)
return FSDP.full_optim_state_dict(
model, optim=optimizer, rank0_only=model.rank0_only
)
elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
return FSDP.sharded_optim_state_dict(model, optimizer)
return FSDP.sharded_optim_state_dict(model, optim=optimizer)
return optimizer.state_dict()
......@@ -26,10 +28,12 @@ def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper
If using full state dict, shard and scatter the optimizer state dict before loading
"""
if model.load_state_dict_type == StateDictType.FULL_STATE_DICT:
optim_state_dict = FSDP.shard_full_optim_state_dict(optim_state_dict, model)
optim_state_dict = FSDP.shard_full_optim_state_dict(
optim_state_dict, model, optim=optimizer
)
elif model.load_state_dict_type == StateDictType.SHARDED_STATE_DICT:
optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
optim_state_dict, model, optimizer
optim_state_dict, model, optim=optimizer
)
optimizer.load_state_dict(optim_state_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment