Unverified Commit 9d14be5c authored by Sylvain Gugger, committed by GitHub

Add support for ZeRO-2/3 and ZeRO-offload in fairscale (#10354)



* Add support for ZeRO-2/3 and ZeRO-offload in fairscale

* Quality

* Rework from review comments

* Add doc

* Apply suggestions from code review
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Address review comments
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
parent 88cc26dc
...@@ -241,6 +241,8 @@ provides support for the following features from `the ZeRO paper <https://arxiv.
1. Optimizer State Sharding
2. Gradient Sharding
3. Model Parameters Sharding (new and very experimental)
4. CPU offload (new and very experimental)
You will need at least two GPUs to use this feature.
...@@ -255,8 +257,9 @@ To deploy this feature:
or find more details on `the FairScale's GitHub page
<https://github.com/facebookresearch/fairscale/#installation>`__.
2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments,
and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
...@@ -268,17 +271,55 @@ For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
--fp16 --sharded_ddp simple
Notes:
- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with ``--fp16`` too, to make things even faster.
- One of the main benefits of enabling ``--sharded_ddp simple`` is that it uses a lot less GPU memory, so you should be
able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
significantly shorter training time.
3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3``
to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
.. code-block:: bash
python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
--fp16 --sharded_ddp zero_dp_2
:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
gradients and optimizer states.
Both are compatible with adding :obj:`offload` to enable ZeRO-offload (activate it like this: :obj:`--sharded_ddp
"zero_dp_2 offload"`).
Notes:
- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with ``--fp16`` too, to make things even faster.
- The ``offload`` additional option requires ``--fp16``.
- This is an area of active development, so make sure you have a source install of fairscale to use this feature as
some bugs you encounter may have been fixed there already.
Known caveats:
- This feature is incompatible with :obj:`--predict_with_generate` in the ``run_seq2seq.py`` script.
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
:obj:`FullyShardedDataParallel` of fairscale. This is not done automatically by the :class:`~transformers.Trainer` or
by any of the example scripts; a minimal sketch of such manual wrapping follows this list.
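Below is a minimal, untested sketch of what that manual wrapping could look like. The helper name and the choice to
wrap every ``torch.nn.ModuleList`` entry are illustrative assumptions, not part of this PR:

.. code-block:: python

    # Hypothetical helper (not provided by transformers or fairscale): wrap each layer held in a
    # ModuleList with FullyShardedDataParallel so that zero_dp_3 can shard parameters layer by layer.
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    def wrap_inner_layers(module: nn.Module) -> nn.Module:
        for name, child in module.named_children():
            if isinstance(child, nn.ModuleList):
                setattr(module, name, nn.ModuleList(FSDP(layer) for layer in child))
            else:
                wrap_inner_layers(child)
        return module

The wrapped model can then be passed to the :class:`~transformers.Trainer` as usual.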
DeepSpeed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
...
...@@ -64,12 +64,13 @@ def require_apex(test_case):
class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
if predict_with_generate:
assert "eval_bleu" in first_step_stats
@require_torch_non_multi_gpu
def test_run_seq2seq_no_dist(self):
...@@ -88,14 +89,28 @@ class TestTrainerExt(TestCasePlus):
# test --sharded_ddp w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
# test --sharded_ddp w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
# test --sharded_ddp zero_dp_2 w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
# test --sharded_ddp zero_dp_2 w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
)
@require_apex
def test_run_seq2seq_apex(self):
...@@ -131,6 +146,7 @@ class TestTrainerExt(TestCasePlus):
num_train_epochs: int,
distributed: bool = False,
extra_args_str: str = None,
predict_with_generate: bool = True,
):
data_dir = self.examples_dir / "test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
...@@ -155,7 +171,6 @@ class TestTrainerExt(TestCasePlus):
--learning_rate 3e-3
--warmup_steps 8
--evaluation_strategy steps
--predict_with_generate
--logging_steps 0
--save_steps {str(eval_steps)}
--eval_steps {str(eval_steps)}
...@@ -165,7 +180,11 @@ class TestTrainerExt(TestCasePlus):
--task translation
--target_lang ro_RO
--source_lang en_XX
"""
if predict_with_generate:
args += "--predict_with_generate"
args = args.split()
if extra_args_str is not None:
args.extend(extra_args_str.split())
...
...@@ -93,6 +93,7 @@ from .trainer_utils import (
EvalPrediction,
HPSearchBackend,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
...@@ -131,10 +132,16 @@ if is_torch_tpu_available():
import torch_xla.distributed.parallel_loader as pl
if is_fairscale_available():
import fairscale
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
else:
FullyShardedDDP = None
if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
...@@ -277,9 +284,38 @@ class Trainer:
else:
self.is_model_parallel = False
# Setup Sharded DDP training
self.sharded_ddp = None
if len(args.sharded_ddp) > 0:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)
if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
raise ImportError(
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
)
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.SIMPLE
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
# one place to sort out whether to place the model on device or not
self.place_model_on_device = args.place_model_on_device
if (
self.is_model_parallel
or (args.deepspeed and args.do_train)
or (args.fp16_full_eval and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
):
self.place_model_on_device = False
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
...@@ -346,21 +382,6 @@ class Trainer:
if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")
# Setup Sharded DDP training
self.sharded_dpp = False
if args.sharded_ddp:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
)
if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
else:
self.sharded_dpp = True
# Mixed precision setup
self.use_apex = False
self.use_amp = False
...@@ -376,7 +397,7 @@ class Trainer:
if args.fp16 and not args.deepspeed:  # deepspeed manages its own fp16
if self.fp16_backend == "amp":
self.use_amp = True
self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
else:
if not is_apex_available():
raise ImportError(
...@@ -619,7 +640,7 @@ class Trainer:
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
...@@ -737,8 +758,19 @@ class Trainer:
return model
# Distributed training (should be after apex fp16 initialization)
if self.sharded_ddp is not None:
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# XXX: Breaking the self.model convention but I see no way around it for now.
self.model = model = FullyShardedDDP(
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
).to(self.args.device)
elif is_sagemaker_distributed_available():
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
elif self.args.local_rank != -1:
...@@ -855,6 +887,7 @@ class Trainer:
num_train_epochs = 1
num_update_steps_per_epoch = max_steps
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
self.model = model.module
...@@ -862,7 +895,7 @@ class Trainer:
self.deepspeed = model  # DeepSpeedEngine object
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
elif not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
...@@ -877,6 +910,9 @@ class Trainer:
if model is not self.model:
self.model_wrapped = model
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
...@@ -1026,6 +1062,9 @@ class Trainer:
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(self.args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
...@@ -1148,8 +1187,8 @@ class Trainer:
def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save except FullyShardedDDP.
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
...@@ -1173,7 +1212,7 @@ class Trainer:
self.deepspeed.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()
if is_torch_tpu_available():
...@@ -1479,7 +1518,11 @@ class Trainer:
# They can then be reloaded using `from_pretrained()`
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
if isinstance(_model_unwrap(self.model), PreTrainedModel):
if xm.is_master_ordinal():
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
...@@ -1494,7 +1537,10 @@ class Trainer:
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
if isinstance(_model_unwrap(self.model), PreTrainedModel):
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
...
...@@ -421,3 +421,10 @@ class TrainerMemoryTracker:
# init doesn't have metrics to update so we just save that data for later stages to retrieve
if metrics is not None:
self.update_metrics(stage, metrics)
class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple"
ZERO_DP_2 = "zero_dp_2"
ZERO_DP_3 = "zero_dp_3"
OFFLOAD = "offload"
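As a rough, editor-added illustration (not part of this diff), a space-separated ``--sharded_ddp`` value is meant to
map onto these members as follows, assuming the merged ``trainer_utils`` module:

# Illustrative sketch only: mirrors the conversion done in TrainingArguments.__post_init__ further below.
from transformers.trainer_utils import ShardedDDPOption

cli_value = "zero_dp_3 offload"
options = [ShardedDDPOption(s) for s in cli_value.split()]
assert options == [ShardedDDPOption.ZERO_DP_3, ShardedDDPOption.OFFLOAD]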
...@@ -25,7 +25,7 @@ from .file_utils import (
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
from .utils import logging
...@@ -236,9 +236,22 @@ class TrainingArguments:
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
step can take a long time) but will not yield the same results as the interrupted training would have.
sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.
A list of options along the following:
- :obj:`"simple"`: to use the first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
to ZeRO-2.
- :obj:`"zero_dp_2"`: to use the second instance of sharded DDP released by fairscale
(:obj:`FullyShardedDDP`) in ZeRO-2 mode (with :obj:`reshard_after_forward=False`).
- :obj:`"zero_dp_3"`: to use the second instance of sharded DDP released by fairscale
(:obj:`FullyShardedDDP`) in ZeRO-3 mode (with :obj:`reshard_after_forward=True`).
- :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`).
If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
deepspeed (:obj:`str`, `optional`):
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
...@@ -443,9 +456,14 @@ class TrainingArguments:
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
},
)
sharded_ddp: str = field(
default="",
metadata={
"choices": ["simple", "zero_dp_2", "zero_dp_3", "zero_dp_2 offload", "zero_dp_3 offload"],
"help": "Whether or not to use sharded DDP training (in distributed training only). The base option "
"should be `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` "
"like this: zero_dp_2 offload` or `zero_dp_3 offload`",
},
)
deepspeed: Optional[str] = field(
default=None,
...@@ -535,6 +553,20 @@ class TrainingArguments:
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
)
if isinstance(self.sharded_ddp, bool):
self.sharded_ddp = "simple" if self.sharded_ddp else ""
if isinstance(self.sharded_ddp, str):
self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
raise ValueError(
"`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
'`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
)
elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
...@@ -662,7 +694,7 @@ class TrainingArguments:
- :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
- :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
:obj:`torch.nn.DistributedDataParallel`).
- :obj:`ParallelMode.TPU`: several TPU cores.
"""
...@@ -692,6 +724,8 @@ class TrainingArguments:
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
return d
def to_json_string(self):
...