Commit feb74214 authored by Chien-Chin Huang's avatar Chien-Chin Huang Committed by Facebook GitHub Bot
Browse files

Add the missing optimizer argument

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/528

Calling shard_full_optim_state_dict() without passing the optimizer object is deprecated. This diff passes the optimizer to shard_full_optim_state_dict() explicitly.

Reviewed By: YanjunChen329

Differential Revision: D45065185

fbshipit-source-id: 0abec3eeff6e7c626eefc432c73e38779a6f02d9
parent 8353ad23
def gather_optimizer_state_dict(optimizer, model: FSDPWrapper):
    """Collect the optimizer state dict according to the model's FSDP state-dict type.

    Args:
        optimizer: the optimizer whose state is being collected.
        model: the FSDP-wrapped model (an ``FSDPWrapper``) that owns the
            sharded parameters; its ``state_dict_type`` selects the mode.

    Returns:
        The optimizer state dict — consolidated across ranks for
        ``FULL_STATE_DICT`` (kept only on rank 0 when ``model.rank0_only``),
        sharded for ``SHARDED_STATE_DICT``, or the plain per-rank
        ``optimizer.state_dict()`` otherwise.
    """
    # FSDP: full_optim_state_dict() needs to be called by all ranks
    if model.state_dict_type == StateDictType.FULL_STATE_DICT:
        return FSDP.full_optim_state_dict(
            model, optim=optimizer, rank0_only=model.rank0_only
        )
    elif model.state_dict_type == StateDictType.SHARDED_STATE_DICT:
        return FSDP.sharded_optim_state_dict(model, optim=optimizer)
    # Neither FSDP consolidation mode applies: return the local state as-is.
    return optimizer.state_dict()
def scatter_optimizer_state_dict(optimizer, optim_state_dict, model: FSDPWrapper):
    """
    If using full state dict, shard and scatter the optimizer state dict before loading
    """
    load_type = model.load_state_dict_type
    if load_type == StateDictType.FULL_STATE_DICT:
        # Re-shard the consolidated state so each rank loads only its slice.
        optim_state_dict = FSDP.shard_full_optim_state_dict(
            optim_state_dict, model, optim=optimizer
        )
    elif load_type == StateDictType.SHARDED_STATE_DICT:
        # Flatten the sharded state into the layout the local optimizer expects.
        optim_state_dict = FSDP.flatten_sharded_optim_state_dict(
            optim_state_dict, model, optim=optimizer
        )
    optimizer.load_state_dict(optim_state_dict)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment