1. 02 May, 2023 1 commit
    • Use FSDP.STATE_DICT_TYPE = SHARDED_STATE_DICT by default · 5ecbb174
      Anthony Chen authored
      Summary:
      Pull Request resolved: https://github.com/facebookresearch/d2go/pull/535
      
      Use `FSDP.STATE_DICT_TYPE = SHARDED_STATE_DICT` for FSDP checkpointing by default. `FSDP.USE_LOCAL_STATE_DICT` will be deprecated in the future.
      
      # Note
      After this change, config usage of `FSDP.USE_LOCAL_STATE_DICT` will no longer be picked up by the code; it is superseded by the default value of `FSDP.STATE_DICT_TYPE` instead.
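
      For illustration only, a minimal sketch of how a trainer might map such a config string onto the PyTorch FSDP enum; the `cfg.FSDP.STATE_DICT_TYPE` attribute name mirrors the wording above and is an assumption, not necessarily the exact D2Go config schema:

      ```
      from torch.distributed.fsdp import StateDictType

      # Hypothetical mapping from the config string to the PyTorch FSDP enum.
      _STATE_DICT_TYPES = {
          "FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
          "LOCAL_STATE_DICT": StateDictType.LOCAL_STATE_DICT,
          "SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,  # new default
      }

      def get_state_dict_type(cfg) -> StateDictType:
          # FSDP.USE_LOCAL_STATE_DICT is no longer consulted once STATE_DICT_TYPE is set.
          return _STATE_DICT_TYPES[cfg.FSDP.STATE_DICT_TYPE]
      ```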
      
      Reviewed By: tglik
      
      Differential Revision: D45413143
      
      fbshipit-source-id: e7bc2d5dc04ac09004cb89353333be020a9c80b5
  2. 13 Jan, 2023 1 commit
    • Support local state dict checkpointing for FSDP · eea6339f
      Anthony Chen authored
      Summary:
      Pull Request resolved: https://github.com/facebookresearch/d2go/pull/457
      
      ## Context:
      
      The PyTorch FSDP (Fully Sharded Data Parallel) backend supports two checkpointing modes. The first is full_state_dict mode, where each FSDP worker summons parameters from the other workers to produce a global state dict that can be loaded by non-FSDP models. This is the desired mode for checkpointing because the checkpoint structure and key names follow the default convention. It is already supported in D39228316 (https://github.com/facebookresearch/d2go/commit/02625ff83207b836df349eadc4a61eb3d4a5810c).
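
      For orientation, a rough sketch of the full_state_dict mode using the public PyTorch FSDP API (not the actual D2Go checkpointer code):

      ```
      import torch
      from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
      from torch.distributed.fsdp import StateDictType, FullStateDictConfig

      def save_full_checkpoint(model: FSDP, path: str) -> None:
          # All ranks cooperate to gather the sharded parameters; only rank 0
          # materializes the full, CPU-offloaded state dict and writes it out.
          full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
          with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
              state = model.state_dict()
          if torch.distributed.get_rank() == 0:
              torch.save(state, path)
      ```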
      
      However, when the model is too large to fit into a single GPU's memory, this approach fails because one worker's GPU cannot hold all the summoned parameters during checkpoint saving. The remedy is the second checkpointing mode: local_state_dict. This mode saves each GPU process's sharded parameters locally. The resulting checkpoint can only be loaded by FSDP-wrapped models with the same distributed training settings (i.e., the same number of processes), but it avoids summoning parameters and greatly reduces peak GPU memory during training.
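
      And the corresponding sketch for the local mode, where each rank persists only its own shard (again illustrative, not the D2Go implementation; the rankN.pth naming follows the layout shown later in this summary):

      ```
      import torch
      from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
      from torch.distributed.fsdp import StateDictType

      def save_local_checkpoint(model: FSDP, checkpoint_dir: str) -> None:
          # Each rank saves only its local shards; nothing is summoned from
          # other workers, so peak GPU memory stays near the training footprint.
          with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
              shard = model.state_dict()
          rank = torch.distributed.get_rank()
          torch.save(shard, f"{checkpoint_dir}/rank{rank}.pth")
      ```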
      
      This diff enables local state dict checkpointing in d2go.
      
      ## API:
      
      This diff supports both **saving** a local state dict and **loading** a state dict that is locally sharded. Whether to save local state is controlled by `FSDP.USE_LOCAL_STATE`. If `FSDP.USE_LOCAL_STATE=True` and the old pattern would have saved `output/model_0000001.pth`, the local checkpoints are instead saved as:
      ```
      - output
          - model_0000001
              - rank0.pth
              - rank1.pth
              - rank2.pth
              - rank3.pth
      ```
      Whether to load local state, on the other hand, is controlled by the path of the checkpoint to load. If the path is a file, e.g. `output/model_final.pth`, the file is loaded as a full state dict by all GPU processes, as before. If the path is a directory, e.g. `output/model_final`, the checkpointer will attempt to load `output/model_final/rankX.pth` on rank X.
      
      This API design enables all combinations of loading local/full state and saving local/full state.
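
      A simplified sketch of the load-side dispatch described above; the function name and layout here are illustrative, not the actual D2Go checkpointer API:

      ```
      import os
      import torch
      import torch.distributed as dist

      def load_checkpoint(path: str):
          # A directory means per-rank local shards (output/model_final/rankX.pth);
          # a single file means a full state dict shared by all ranks.
          if os.path.isdir(path):
              rank = dist.get_rank() if dist.is_initialized() else 0
              return torch.load(os.path.join(path, f"rank{rank}.pth"), map_location="cpu")
          return torch.load(path, map_location="cpu")
      ```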
      
      ## Conversion to full state dict [Temporary]
      
      Conversion from a local state dict to a full state dict is needed during an e2e workflow. This will be implemented in another diff.
      
      Reviewed By: wat3rBro
      
      Differential Revision: D41861308
      
      fbshipit-source-id: 2e01b601683d06b46f0c5517c6cff30bbcffa8f7