[FSDP] Add an arg for FSDP __init__ (#926)
* [FSDP] Add an arg for FSDP __init__ Add an arg, disable_reshard_on_root, for FSDP __init__ to handle the following issue https://github.com/facebookresearch/fairscale/issues/878 For some cases (models wrapped by autowrap), the parameters (of root modules) needs to be sharded, and reshard_after_forward should not be set to False. "disable_reshard_on_root" is for users to choose whether to force reshard_after_forward of root modules to be False or not. * Update fully_sharded_data_parallel.py Modified the description of the feature to explain more clear. * Update fairscale/nn/data_parallel/fully_sharded_data_parallel.py Update the comments for disable_reshard_on_root Co-authored-by:Min Xu <24926999+min-xu-ai@users.noreply.github.com> * Modified the comments Modified the comments of disable_reshard_on_root Co-authored-by:
Min Xu <24926999+min-xu-ai@users.noreply.github.com>
Showing
Please register or sign in to comment