Enable PyTorch/XLA Fully Sharded Data Parallel (FSDP) (#21406)
* Reinserted an import statement accidentally removed during rebasing.
* Added auto_wrap functionality; restructured XLA FSDP logic to more closely match PyTorch FSDP logic.
* Fixed flag descriptions; changed several instances of fsdp_ to xla_fsdp_; passed auto_wrap_policy and auto_wrapper_callable in directly to avoid lambda saving.
* Moved XLA FSDP logic to be adjacent to the Fairscale FSDP logic in the trainer.
* Formatted changes in accordance with HF style requirements.
* Added back a warning that was accidentally removed.
* Merged XLA FSDP support into `fsdp_config`:
  - Merged XLA FSDP training arguments into `fsdp_config`
  - Added an `xla` boolean flag to `fsdp_config` to specify XLA FSDP wrapping
  - Merged XLA FSDP wrapping logic into the FSDP wrapping logic within the Trainer class
* Cleaned up errors, moved argument to `fsdp_config`:
  - Set `xla` and `xla_fsdp_grad_ckpt` flags by default in `fsdp_config`
  - Added missing colons following conditionals
  - Moved `fsdp_transformer_layer_cls_to_wrap` to `fsdp_config`
  - Modified `fsdp_transformer_layer_cls_to_wrap` to be a list of strings, not just one string
  - Changed Fairscale FSDP logic to allow for a set of layer classes to wrap
  - Removed unnecessary checks for `xla_fsdp`
* Corrected small errors, improved the layer class flag:
  - Correctly set default values for the `xla` and `xla_fsdp_grad_ckpt` arguments
  - Made `fsdp_transformer_layer_cls_to_wrap` a list of strings instead of a single string
  - Added processing so that `fsdp_transformer_layer_cls_to_wrap` works as expected when passed as a single string
  - Updated PyTorch FSDP logic to accept a list of layers to wrap, as done with XLA FSDP
  - Replaced instances of `getattr()` with `.get()` for dictionary retrievals with default values, including when setting `fsdp_min_num_params`
  - Corrected `self.fsdp is not None` to `len(self.fsdp) > 0`
  - Removed extraneous `xla_fsdp` argument descriptions from outside `fsdp_config`
* Changed `xla_fsdp_settings` to a dictionary:
  - Modified `xla_fsdp_settings` to be entered directly as a dictionary instead of loaded through a JSON file
  - Made small style corrections
* Reverted unintentional local_rank TPU check
* Do not block XLA FSDP if local rank is -1
* Rebased and applied automatic formatting changes via `make style`
* Applied automatic formatting with the latest version of black
* Replaced expression with
* Reran `black examples tests src utils`, `ruff examples tests src utils --fix`, and `make autogenerate_code` after additional formatting changes
* Additional automatic formatting changes
* Remove unnecessary whitespace characters from src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
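
For reference, a minimal sketch of how the merged options might be used, assuming the `fsdp_config` keys described above (`xla`, `xla_fsdp_grad_ckpt`, `fsdp_transformer_layer_cls_to_wrap`, `xla_fsdp_settings`) and an illustrative layer class name; exact key names and accepted values should be checked against the Trainer documentation for the installed version.

```python
# Minimal sketch (not part of this PR's diff): enabling XLA FSDP through the
# merged `fsdp_config` options. "GPT2Block" is only an illustrative layer class;
# key names and defaults may differ between releases.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./xla_fsdp_out",
    fsdp="full_shard",  # request FSDP sharding
    fsdp_config={
        "xla": True,  # wrap with PyTorch/XLA FSDP instead of native PyTorch FSDP
        "xla_fsdp_grad_ckpt": True,  # gradient checkpointing on each wrapped block
        "fsdp_transformer_layer_cls_to_wrap": ["GPT2Block"],  # layer classes for auto_wrap
        "xla_fsdp_settings": {},  # extra kwargs passed through to the XLA FSDP wrapper
    },
)
```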