Unverified Commit da2a4d95 authored by raghavanone, committed by GitHub

Add support of backward_prefetch and forward_prefetch (#21237)



* Add support of backward_prefetch and forward_prefetch

* Fix format issue

* Fix isort issue

* Fix doc style issue

* Update src/transformers/trainer.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Fix black issue

* Fix doc-style issue

* Make additional fsdp parameters into fsdp config

* Fix black issue

* Remove unused imports

* Fix doc style issues

* Incorporate PR feedbacks

* Remove unused imports

* Fix tests

* Fix tests

* Fix tests

* Fix tests

* Fix tests

* Update src/transformers/training_args.py
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>

* Fix tests

* Incorporate PR feedbacks

* Incorporate PR feedbacks

* Fix black issues

---------
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
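For context, a minimal usage sketch of what this commit enables. It assumes the `fsdp_config` keys documented in the `training_args.py` changes further down (`fsdp_min_num_params`, `fsdp_backward_prefetch`, `fsdp_forward_prefetch`); the output directory and the parameter threshold are illustrative placeholders, not values taken from the patch.

```python
from transformers import TrainingArguments

# Sketch only: the fsdp_config keys follow the docstring added in this commit;
# "./out" and 2000 are placeholder values.
fsdp_config = {
    "fsdp_min_num_params": 2000,                # size-based auto-wrap threshold
    "fsdp_backward_prefetch": "backward_post",  # default is "backward_pre"
    "fsdp_forward_prefetch": True,              # prefetch next all-gather in forward
}

args = TrainingArguments(
    output_dir="./out",
    fsdp="full_shard auto_wrap",
    fsdp_config=fsdp_config,  # a dict or a path to a json file, per the new docstring
)
print(args.fsdp_config["fsdp_min_num_params"])  # 2000
```

Passing a JSON file path instead of a dict would go through the `json.load` branch added to `__post_init__` in the diff below.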
parent 074d6b75
src/transformers/trainer.py
@@ -408,7 +408,7 @@ class Trainer:
if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
raise ValueError("FSDP requires PyTorch >= 1.12.0")
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
if FSDPOption.FULL_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.FULL_SHARD
@@ -417,6 +417,14 @@ class Trainer:
elif FSDPOption.NO_SHARD in args.fsdp:
self.fsdp = ShardingStrategy.NO_SHARD
self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if self.args.fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post":
    self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
self.forward_prefetch = False
if self.args.fsdp_config.get("fsdp_forward_prefetch", False):
    self.forward_prefetch = True
# one place to sort out whether to place the model on device or not
# postpone switching model to cuda when:
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
@@ -1401,10 +1409,11 @@ class Trainer:
cpu_offload = CPUOffload(offload_params=False)
auto_wrap_policy = None
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_min_num_params > 0:
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_min_num_params
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = get_module_class_from_name(
@@ -1434,6 +1443,8 @@ class Trainer:
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
backward_prefetch=self.backward_prefetch,
forward_prefetch=self.forward_prefetch,
)
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
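The trainer.py hunks above translate the string and boolean options from `fsdp_config` into the arguments PyTorch's `FullyShardedDataParallel` expects: `backward_prefetch` as a `BackwardPrefetch` enum member and `forward_prefetch` as a bool. A small sketch of that mapping in isolation; `resolve_backward_prefetch` is a hypothetical helper, not part of the patch, and the import assumes PyTorch >= 1.12.

```python
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch

def resolve_backward_prefetch(fsdp_config: dict) -> BackwardPrefetch:
    # Hypothetical helper mirroring the logic added above: default to
    # BACKWARD_PRE, switch to BACKWARD_POST only when the config asks for it.
    if fsdp_config.get("fsdp_backward_prefetch", "") == "backward_post":
        return BackwardPrefetch.BACKWARD_POST
    return BackwardPrefetch.BACKWARD_PRE

print(resolve_backward_prefetch({"fsdp_backward_prefetch": "backward_post"}))
# BackwardPrefetch.BACKWARD_POST
print(resolve_backward_prefetch({}))
# BackwardPrefetch.BACKWARD_PRE
```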
src/transformers/training_args.py
@@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import io
import json
import math
import os
@@ -407,8 +408,30 @@ class TrainingArguments:
- `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
`"shard_grad_op"`).
- `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
fsdp_min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed).
fsdp_config (`str` or `dict`, *optional*):
Config to be used with FSDP (PyTorch Fully Sharded Data Parallel). The value is either the location of
an FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as a `dict`.
A list of config options:
- fsdp_min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
passed).
- fsdp_backward_prefetch (`str`, *optional*):
FSDP's backward prefetch mode. Controls when to prefetch the next set of parameters (useful only when
`fsdp` field is passed).
One of the following options:
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameters'
gradient computation.
- `"backward_post"` : Prefetches the next set of parameters after the current set of parameters'
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`):
FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `True`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward
pass.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
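Since the new `fsdp_config` argument can also be a path, here is an illustrative JSON file matching the options documented in the hunk above; the file name and values are placeholders, and the read-back mirrors the `json.load` branch added to `__post_init__` further down.

```python
import json

# Illustrative fsdp_config.json; file name and values are placeholders.
config = {
    "fsdp_min_num_params": 1000000,
    "fsdp_backward_prefetch": "backward_pre",
    "fsdp_forward_prefetch": False,
}
with open("fsdp_config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

# TrainingArguments(..., fsdp_config="fsdp_config.json") would load the file
# in __post_init__ much like this:
with open("fsdp_config.json", "r", encoding="utf-8") as f:
    print(json.load(f))
```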
@@ -871,8 +894,17 @@ class TrainingArguments:
default=0,
metadata={
"help": (
"FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is"
" passed)."
"This parameter is deprecatetd. FSDP's minimum number of parameters for Default Auto Wrapping. (useful"
" only when `fsdp` field is passed)."
)
},
)
fsdp_config: Optional[str] = field(
default=None,
metadata={
"help": (
"Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a"
"fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
)
},
)
@@ -1278,13 +1310,31 @@ class TrainingArguments:
elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
if len(self.fsdp) == 0 and self.fsdp_min_num_params > 0:
if self.fsdp_config is None:
self.fsdp_config = {}
if isinstance(self.fsdp_config, str):
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
self.fsdp_config["fsdp_min_num_params"] = max(
getattr(self.fsdp_config, "fsdp_min_num_params", 0), self.fsdp_min_num_params
)
if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if len(self.fsdp) > 0 and self.fsdp_min_num_params > 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
if (
len(self.fsdp) > 0
and self.fsdp_config["fsdp_min_num_params"] > 0
and self.fsdp_transformer_layer_cls_to_wrap is not None
):
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
)
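The validation above also migrates the now-deprecated `--fsdp_min_num_params` flag into `fsdp_config["fsdp_min_num_params"]`, keeping the larger of the two values and emitting a `FutureWarning`. A standalone sketch of that merge; the helper name is hypothetical and not part of the patch.

```python
import warnings

def merge_min_num_params(fsdp_config: dict, fsdp_min_num_params: int) -> dict:
    # Hypothetical standalone version of the merge done in __post_init__ above.
    fsdp_config.setdefault("fsdp_min_num_params", 0)
    if fsdp_min_num_params > 0:
        warnings.warn(
            "using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead", FutureWarning
        )
        fsdp_config["fsdp_min_num_params"] = max(
            fsdp_config["fsdp_min_num_params"], fsdp_min_num_params
        )
    return fsdp_config

print(merge_min_num_params({}, 2000))                             # {'fsdp_min_num_params': 2000}
print(merge_min_num_params({"fsdp_min_num_params": 5000}, 2000))  # keeps 5000
```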