Unverified commit da2a4d95 authored by raghavanone, committed by GitHub

Add support for backward_prefetch and forward_prefetch (#21237)



* Add support for backward_prefetch and forward_prefetch

* Fix format issue

* Fix isort issue

* Fix doc style issue

* Update src/transformers/trainer.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Fix black issue

* Fix doc-style issue

* Move the additional FSDP parameters into fsdp_config

* Fix black issue

* Remove unused imports

* Fix doc style issues

* Incorporate PR feedback

* Remove unused imports

* Fix tests

* Fix tests

* Fix tests

* Fix tests

* Fix tests

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Fix tests

* Incorporate PR feedback

* Incorporate PR feedback

* Fix black issues

---------
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
parent 074d6b75
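
As a usage sketch (not part of the commit): the new options are driven through `fsdp_config` on `TrainingArguments`. The key names below follow the docstring added in this PR; note that the Trainer code in this diff looks the prefetch values up under the shorter spellings `backward_prefetch`/`forword_prefetch`, and the output directory and values here are placeholder choices, not defaults.

from transformers import TrainingArguments

# Hedged example: enable FSDP and route the new prefetch options through fsdp_config.
# Key names follow the TrainingArguments docstring added in this PR; values are illustrative.
training_args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard",  # shard parameters, gradients and optimizer states
    fsdp_config={
        "fsdp_backward_prefetch": "backward_pre",  # or "backward_pos", as spelled in this diff
        "fsdp_forward_prefetch": True,             # prefetch the next all-gather during the forward pass
    },
)

Equivalently, the same dictionary can be saved as a JSON file and passed as `--fsdp_config path/to/file.json` on the command line.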
src/transformers/trainer.py

@@ -408,7 +408,7 @@ class Trainer:
             if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
                 raise ValueError("FSDP requires PyTorch >= 1.12.0")
-            from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
+            from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy

             if FSDPOption.FULL_SHARD in args.fsdp:
                 self.fsdp = ShardingStrategy.FULL_SHARD
@@ -417,6 +417,14 @@ class Trainer:
             elif FSDPOption.NO_SHARD in args.fsdp:
                 self.fsdp = ShardingStrategy.NO_SHARD
+
+            self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
+            if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch:
+                self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
+
+            self.forword_prefetch = False
+            if "forword_prefetch" in self.args.fsdp_config and self.backward_prefetch:
+                self.forword_prefetch = True

         # one place to sort out whether to place the model on device or not
         # postpone switching model to cuda when:
         # 1. MP - since we are trying to fit a much bigger than 1 gpu model
@@ -1401,10 +1409,11 @@ class Trainer:
             cpu_offload = CPUOffload(offload_params=False)

             auto_wrap_policy = None

             if FSDPOption.AUTO_WRAP in self.args.fsdp:
-                if self.args.fsdp_min_num_params > 0:
+                if self.args.fsdp_config["fsdp_min_num_params"] > 0:
                     auto_wrap_policy = functools.partial(
-                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
+                        size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
                     )
                 elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
                     transformer_cls_to_wrap = get_module_class_from_name(
@@ -1434,6 +1443,8 @@
                         auto_wrap_policy=auto_wrap_policy,
                         mixed_precision=mixed_precision_policy,
                         device_id=self.args.device,
+                        backward_prefetch=self.backward_prefetch,
+                        forward_prefetch=self.forword_prefetch,
                     )
         elif is_sagemaker_dp_enabled():
             model = nn.parallel.DistributedDataParallel(
...
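
For readability, the sketch below restates (outside the library) how the prefetch-related entries of `fsdp_config` are intended to translate into the arguments handed to torch's FullyShardedDataParallel in the hunks above. It is a simplified, hedged reading of the diff, not the Trainer code itself; in particular it uses the `backward_pos` and `forword_prefetch` spellings that appear in this commit.

# Illustrative sketch (not library code): mapping fsdp_config entries to the
# backward_prefetch / forward_prefetch arguments passed to FSDP above.
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch


def resolve_prefetch(fsdp_config: dict):
    # Default in the Trainer: prefetch the next parameter set before the current
    # gradient computation.
    backward_prefetch = BackwardPrefetch.BACKWARD_PRE
    # "backward_pos" is this commit's spelling for the post-computation mode.
    if fsdp_config.get("backward_prefetch") == "backward_pos":
        backward_prefetch = BackwardPrefetch.BACKWARD_POST

    # In this commit, forward prefetch is enabled by the presence of the key
    # (note the "forword" spelling used throughout the diff).
    forward_prefetch = "forword_prefetch" in fsdp_config
    return backward_prefetch, forward_prefetch


# Example: BackwardPrefetch.BACKWARD_POST and True would be forwarded to FSDP(...).
print(resolve_prefetch({"backward_prefetch": "backward_pos", "forword_prefetch": True}))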
src/transformers/training_args.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import contextlib
+import io
 import json
 import math
 import os
@@ -407,8 +408,30 @@ class TrainingArguments:
             - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
               `"shard_grad_op"`).
             - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
-        fsdp_min_num_params (`int`, *optional*, defaults to `0`):
-            FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed).
+        fsdp_config (`str` or `dict`, *optional*):
+            Config to be used with fsdp (Pytorch Distributed Parallel Training). The value is either a location of
+            deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.
+
+            A List of config and its options:
+                - fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+                    FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
+                    passed).
+                - fsdp_backward_prefetch (`str`, *optional*)
+                    FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
+                    `fsdp` field is passed).
+
+                    A list of options along the following:
+
+                    - `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
+                      gradient computation.
+                    - `"backward_pos"` : This prefetches the next set of parameters after the current set of
+                      parameter's gradient computation.
+                - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
+                    FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
+                    If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
+                    forward pass.
         deepspeed (`str` or `dict`, *optional*):
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -871,8 +894,17 @@
         default=0,
         metadata={
             "help": (
-                "FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is"
-                " passed)."
+                "This parameter is deprecatetd. FSDP's minimum number of parameters for Default Auto Wrapping. (useful"
+                " only when `fsdp` field is passed)."
+            )
+        },
+    )
+    fsdp_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a"
+                "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
             )
         },
     )
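
To make the new field concrete, the snippet below writes an example config file of the kind the help text above refers to; the file name, keys, and values are illustrative (taken from the docstring earlier in this diff), not defaults shipped with the library.

# Illustrative only: produce an fsdp_config.json that can be passed as
# `--fsdp_config fsdp_config.json` (or the dict can be passed directly).
import json

example_fsdp_config = {
    "fsdp_min_num_params": 100_000,            # size-based auto-wrap threshold
    "fsdp_backward_prefetch": "backward_pre",  # "backward_pre" or "backward_pos"
    "fsdp_forward_prefetch": True,             # prefetch the next all-gather in the forward pass
}

with open("fsdp_config.json", "w", encoding="utf-8") as f:
    json.dump(example_fsdp_config, f, indent=2)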
@@ -1278,13 +1310,31 @@
         elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
             raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")

-        if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
+        if self.fsdp_config is None:
+            self.fsdp_config = {}
+        if isinstance(self.fsdp_config, str):
+            with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
+                self.fsdp_config = json.load(f)
+
+        if self.fsdp_min_num_params > 0:
+            warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
+
+        self.fsdp_config["fsdp_min_num_params"] = max(
+            getattr(self.fsdp_config, "fsdp_min_num_params", 0), self.fsdp_min_num_params
+        )
+
+        if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
             warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")

         if len(self.fsdp) == 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
             warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

-        if len(self.fsdp) > 0 and self.fsdp_min_num_params > 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
+        if (
+            len(self.fsdp) > 0
+            and self.fsdp_config["fsdp_min_num_params"] > 0
+            and self.fsdp_transformer_layer_cls_to_wrap is not None
+        ):
             raise ValueError(
                 "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
             )
...
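
As a reading aid for the `__post_init__` hunk above, the helper below restates how `fsdp_config` is normalized: a missing config becomes an empty dict, a string is treated as the path to a JSON file, and the deprecated `--fsdp_min_num_params` flag is folded in. This is a hedged re-statement, not the library code; where the diff uses `getattr` on the config, the sketch uses `dict.get` to express the same intent.

# Illustrative sketch (not library code): normalization applied to fsdp_config.
import json
import warnings


def normalize_fsdp_config(fsdp_config, fsdp_min_num_params: int = 0) -> dict:
    # Fall back to an empty dict so later key lookups are always safe.
    if fsdp_config is None:
        fsdp_config = {}
    # A string value is interpreted as the path to a JSON config file.
    if isinstance(fsdp_config, str):
        with open(fsdp_config, "r", encoding="utf-8") as f:
            fsdp_config = json.load(f)
    # The standalone flag is deprecated but still honored; the larger value wins.
    if fsdp_min_num_params > 0:
        warnings.warn("`--fsdp_min_num_params` is deprecated. Use fsdp_config instead.", FutureWarning)
    fsdp_config["fsdp_min_num_params"] = max(fsdp_config.get("fsdp_min_num_params", 0), fsdp_min_num_params)
    return fsdp_config


# Example: a config value of 500 combined with a deprecated flag of 1000 yields 1000.
print(normalize_fsdp_config({"fsdp_min_num_params": 500}, fsdp_min_num_params=1000))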