Add support for FSDP SHARDED_STATE_DICT in D2Go
Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/512

Currently, when saving and loading checkpoints for FSDP-wrapped modules, we use `StateDictType.LOCAL_STATE_DICT`, where the state dict becomes essentially a single flat tensor under the `_flat_param` key (or another layer-specific key for flat weights). This means that:
1. It's impossible to load weights directly from checkpoints, for example in notebooks.
2. Converting from a local to a global checkpoint requires running a special workflow (https://fburl.com/code/6yqa4ldb) that occupies the same number of GPUs as was used during training.

This diff adds an option, `FSDP.STATE_DICT_TYPE`, which selects the type of state dict to save (local, sharded, or full); see the sketch below. In sharded mode, with AIF checkpointing, state dicts can be loaded locally in minutes with any number of GPUs, in notebooks and elsewhere.

Note: for backwards compatibility, `CFG.FSDP.use_local_state_dict` and `CFG.FSDP.load_local_state_dict` must still work when the new config parameter (`CFG.FSDP.state_dict_type`) is not set. These legacy flags also signal that local/sharded state dicts need to be converted to a full state dict when loading. This functionality can be deprecated once everyone migrates to AIF checkpointing with sharded dicts.

Reviewed By: YanjunChen329

Differential Revision: D43840887

fbshipit-source-id: d112f7b7ad97ba82fd5bf1da986b95ad7fc61c42
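For reference, here is a minimal sketch of how a config value like `CFG.FSDP.STATE_DICT_TYPE` could be mapped onto PyTorch's state-dict selection. It is not the actual D2Go implementation: the string values (`"local"`, `"sharded"`, `"full"`) and the `save_checkpoint` helper are hypothetical, while `StateDictType`, `FullStateDictConfig`, and the `FSDP.state_dict_type` context manager are real `torch.distributed.fsdp` APIs.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Hypothetical mapping from the new config value to PyTorch's enum.
_STATE_DICT_TYPES = {
    "local": StateDictType.LOCAL_STATE_DICT,      # flat per-rank tensors (`_flat_param`)
    "sharded": StateDictType.SHARDED_STATE_DICT,  # sharded entries per parameter
    "full": StateDictType.FULL_STATE_DICT,        # unsharded, full-size tensors
}

def save_checkpoint(model: FSDP, path: str, state_dict_type: str = "sharded") -> None:
    """Save the model's state dict using the configured FSDP state-dict type."""
    sd_type = _STATE_DICT_TYPES[state_dict_type]
    if sd_type == StateDictType.FULL_STATE_DICT:
        # Gather the full state dict onto rank 0 only, offloaded to CPU.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(model, sd_type, cfg):
            state_dict = model.state_dict()
        if dist.get_rank() == 0:
            torch.save(state_dict, path)
    else:
        # Local and sharded state dicts are per-rank artifacts. In practice a
        # distributed checkpointing backend (AIF, in this diff) manages the
        # sharded layout; plain per-rank torch.save is used here only to keep
        # the sketch self-contained.
        with FSDP.state_dict_type(model, sd_type):
            state_dict = model.state_dict()
        torch.save(state_dict, f"{path}.rank{dist.get_rank()}")
```

The key point is that `FSDP.state_dict_type` temporarily switches how `model.state_dict()` is produced, so a single config knob can select among the three layouts without changing the call sites that save checkpoints.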