Unverified commit 7735e040, authored by AlexWertheim, committed by GitHub

Enable PyTorch/XLA Fully Sharded Data Parallel (FSDP) (#21406)



* Reinserted import statement accidentally removed during rebasing.

* Added auto_wrap functionality, restructured XLA FSDP logic to more closely match PyTorch FSDP logic.

* Fixed flag descriptions; changed several instances of fsdp_ to xla_fsdp_; pass in auto_wrap_policy and auto_wrapper_callable directly to avoid lambda saving.

* Moved XLA FSDP logic to be adjacent to Fairscale FSDP logic in trainer.

* Formatted changes in accordance with HF style requirements.

* Added back in warning which was accidentally removed.

* Merged XLA FSDP training arguments into `fsdp_config`

- Added `xla` boolean flag to `fsdp_config` to specify XLA FSDP wrapping
- Merged XLA FSDP wrapping logic into the FSDP wrapping logic within the `Trainer` class

* Cleaned up errors, moved argument to fsdp_config

- Set `xla` and `xla_fsdp_grad_ckpt` flags by default in fsdp_config
- Added missing colons following conditionals
- Moved `fsdp_transformer_layer_cls_to_wrap` to `fsdp_config`
- Modified `fsdp_transformer_layer_cls_to_wrap` to be a list of strings,
  not just a single string (an example `fsdp_config` sketch follows this list)
- Changed Fairscale FSDP logic to allow for set of layer classes to wrap
- Removed unnecessary checks for `xla_fsdp`

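For illustration, here is a minimal sketch of what the reorganized `fsdp_config` described in the bullets above might look like when built programmatically. The key names follow this change; the concrete values and the `BertLayer` class are purely illustrative.

```python
import json

# Hypothetical example of the reorganized `fsdp_config` dictionary.
# Key names follow this change; the values shown are illustrative only.
fsdp_config = {
    # Selects the PyTorch/XLA FSDP code path instead of native PyTorch FSDP.
    "xla": True,
    # Applies gradient checkpointing to each auto-wrapped XLA FSDP sub-module.
    "xla_fsdp_grad_ckpt": True,
    # Now a list of layer class names; a single string is also accepted and
    # normalized to a one-element list during argument post-processing.
    "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"],
}

# The same structure can be stored as JSON and referenced from the command
# line, e.g. `--fsdp "full_shard auto_wrap" --fsdp_config fsdp_config.json`.
with open("fsdp_config.json", "w") as f:
    json.dump(fsdp_config, f, indent=2)
```
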
* Corrected small errors, improved layer class flag

- Correctly set default values for `xla` and `xla_fsdp_grad_ckpt`
  arguments
- Made `fsdp_transformer_layer_cls_to_wrap` a list of strings instead of
  a single string
- Added processing to ensure that `fsdp_transformer_layer_cls_to_wrap`
  works as expected if passed as a single string
- Updated PyTorch FSDP logic to accept a list of layers to wrap, as done
  with XLA FSDP
- Replaced instances of `getattr()` with `.get()` for dictionary
  retrievals with default values, including when setting
  `fsdp_min_num_params` (see the short demonstration below)
- Corrected `self.fsdp is not None` to `len(self.fsdp) > 0`
- Removed extraneous `xla_fsdp` argument descriptions from outside
  `fsdp_config`

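The `getattr()` to `.get()` swap mentioned above fixes a silent lookup bug: `getattr` reads attributes, not dictionary keys, so on a plain dict it always falls through to the supplied default. A minimal demonstration:

```python
# Why `.get()` is the right call for dictionary retrievals with defaults:
# `getattr` looks up *attributes*, not keys, so on a plain dict it ignores
# the configured value and silently returns the default.
fsdp_config = {"fsdp_min_num_params": 100_000_000}

print(getattr(fsdp_config, "fsdp_min_num_params", 0))  # 0 (key is ignored)
print(fsdp_config.get("fsdp_min_num_params", 0))       # 100000000
```
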
* Changed `xla_fsdp_settings` to be a dictionary

- Modified `xla_fsdp_settings` to be entered directly as a dictionary
  instead of being loaded from a JSON file (see the sketch below)
- Made small style corrections

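A sketch of how an inline `xla_fsdp_settings` entry might be written, and of the string-to-`torch.dtype` conversion applied to it in `TrainingArguments.__post_init__` (visible in the diff below). Only the `compute_dtype`/`buffer_dtype` handling is taken from this change; the surrounding values are illustrative, not an authoritative list of XLA FSDP parameters.

```python
import torch

# Illustrative nested `xla_fsdp_settings` dictionary inside `fsdp_config`.
# Any keys besides the dtype entries would be forwarded as keyword arguments
# to the XLA FSDP wrapper.
fsdp_config = {
    "xla": True,
    "xla_fsdp_settings": {
        "compute_dtype": "bfloat16",
        "buffer_dtype": "bfloat16",
    },
}

# Mirror of the string -> torch.dtype conversion performed at post-init time.
xla_fsdp_config = dict(fsdp_config["xla_fsdp_settings"])
for key in ("compute_dtype", "buffer_dtype"):
    if key in xla_fsdp_config:
        xla_fsdp_config[key] = getattr(torch, xla_fsdp_config[key])

print(xla_fsdp_config)
# {'compute_dtype': torch.bfloat16, 'buffer_dtype': torch.bfloat16}
```
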
* Reverted unintentional local_rank TPU check

* Do not block XLA FSDP if local rank is -1

* Rebased and applied automatic formatting

- Rebased
- Applied automatic formatting changes via `make style`

* Applied automatic formatting with latest version of black

* Replaced  expression with

* Reran `black examples tests src utils`, `ruff examples tests src utils --fix`, and `make autogenerate_code` after additional formatting changes

* Additional automatic formatting changes

* Remove unnecessary whitespace characters from src/transformers/training_args.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 7f1cdf18
src/transformers/trainer.py
@@ -417,7 +417,7 @@ class Trainer:
raise ValueError(
"Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if args.local_rank == -1:
if not args.fsdp_config["xla"] and args.local_rank == -1:
raise ValueError("Using fsdp only works in distributed training.")
# dep_version_check("torch>=1.12.0")
@@ -1419,55 +1419,110 @@ class Trainer:
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)
if not self.args.fsdp_config["xla"]:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)
auto_wrap_policy = None
auto_wrap_policy = None
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
backward_prefetch=self.backward_prefetch,
forward_prefetch=self.forword_prefetch,
limit_all_gathers=self.limit_all_gathers,
)
else:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = get_module_class_from_name(
model, self.args.fsdp_transformer_layer_cls_to_wrap
)
if transformer_cls_to_wrap is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls={transformer_cls_to_wrap},
transformer_layer_cls=transformer_cls_to_wrap,
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)
# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
backward_prefetch=self.backward_prefetch,
forward_prefetch=self.forword_prefetch,
limit_all_gathers=self.limit_all_gathers,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)
# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss
xm.optimizer_step = patched_optimizer_step
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
src/transformers/training_args.py
@@ -416,6 +416,9 @@ class TrainingArguments:
- fsdp_min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
passed).
- fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`,
`T5Block` .... (useful only when `fsdp` flag is passed).
- fsdp_backward_prefetch (`str`, *optional*)
FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
`fsdp` field is passed).
@@ -436,6 +439,19 @@ class TrainingArguments:
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
- xla_fsdp_settings (`dict`, *optional*)
The value is a dictionary which stores the XLA FSDP wrapping parameters.
For a complete list of options, please see [here](
https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
- xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`):
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
@@ -898,7 +914,7 @@ class TrainingArguments:
default=0,
metadata={
"help": (
"This parameter is deprecatetd. FSDP's minimum number of parameters for Default Auto Wrapping. (useful"
"This parameter is deprecated. FSDP's minimum number of parameters for Default Auto Wrapping. (useful"
" only when `fsdp` field is passed)."
)
},
@@ -916,8 +932,8 @@ class TrainingArguments:
default=None,
metadata={
"help": (
"Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
"(useful only when `fsdp` flag is passed)."
"This parameter is deprecated. Transformer layer class name (case-sensitive) to wrap, e.g,"
" `BertLayer`, `GPTJBlock`, `T5Block` .... (useful only when `fsdp` flag is passed)."
)
},
)
@@ -1324,23 +1340,53 @@ class TrainingArguments:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
self.fsdp_config["fsdp_min_num_params"] = max(
getattr(self.fsdp_config, "fsdp_min_num_params", 0), self.fsdp_min_num_params
self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
)
# if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
]
if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]
if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_transformer_layer_cls_to_wrap is not None:
if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if (
len(self.fsdp) > 0
and self.fsdp_config["fsdp_min_num_params"] > 0
and self.fsdp_transformer_layer_cls_to_wrap is not None
and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
)
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
if len(self.fsdp) > 0:
# store XLA fsdp configuration parameters into a dictionary
self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {})
# apply appropriate string to torch.dtype conversions for parameters
if "compute_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
if "buffer_dtype" in self.xla_fsdp_config:
self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"])
else:
warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.")
else:
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
if self.tpu_metrics_debug:
warnings.warn(
......