Unverified Commit 2df34f4a authored by Stas Bekman, committed by GitHub

[trainer] deepspeed integration (#9211)



* deepspeed integration

* style

* add test

* ds wants to do its own backward

* fp16 assert

* Update src/transformers/training_args.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* style

* for clarity extract what args are being passed to deepspeed

* introduce the concept of self.wrapped_model

* s/self.wrapped_model/self.model_wrapped/

* complete transition to self.wrapped_model / self.model

* fix

* doc

* give ds its own init

* add custom overrides, handle bs correctly

* fix test

* clean up model_init logic, fix small bug

* complete fix

* collapse --deepspeed_config into --deepspeed

* style

* start adding doc notes

* style

* implement hf2ds optimizer and scheduler configuration remapping

* oops

* call get_num_training_steps absolutely when needed

* workaround broken auto-formatter

* deepspeed_config arg is no longer needed - fixed in deepspeed master

* use hf's fp16 args in config

* clean

* start on the docs

* rebase cleanup

* finish up --fp16

* clarify the supported stages

* big refactor thanks to discovering deepspeed.init_distributed

* cleanup

* revert fp16 part

* add checkpoint-support

* more init ds into integrations

* extend docs

* cleanup

* unfix docs

* clean up old code

* imports

* move docs

* fix logic

* make it clear which file it's referring to

* document nodes/gpus

* style

* wrong format

* style

* deepspeed handles gradient clipping

* easier to read

* major doc rewrite

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* docs

* switch to AdamW optimizer

* style

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* clarify doc
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 5f672103
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
......@@ -96,3 +96,438 @@ TFTrainingArguments
.. autoclass:: transformers.TFTrainingArguments
:members:
Trainer Integrations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The :class:`~transformers.Trainer` has been extended to support libraries that may dramatically improve your training
time and fit much bigger models.
Currently it supports third party solutions, `DeepSpeed <https://github.com/microsoft/DeepSpeed>`__ and `FairScale
<https://github.com/facebookresearch/fairscale/>`__, which implement parts of the paper `ZeRO: Memory Optimizations
Toward Training Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He
<https://arxiv.org/abs/1910.02054>`__.
This support is new and experimental as of this writing.
You will need at least 2 GPUs to benefit from these features.
FairScale
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
By integrating `FairScale <https://github.com/facebookresearch/fairscale/>`__ the :class:`~transformers.Trainer`
provides support for the following features from `the ZeRO paper <https://arxiv.org/abs/1910.02054>`__:
1. Optimizer State Sharding
2. Gradient Sharding
To deploy this feature:
1. Install the library via pypi:
.. code-block:: bash
pip install fairscale
or find more details on `FairScale's GitHub page
<https://github.com/facebookresearch/fairscale/#installation>`__.
2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m
torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
For example, here is how you could use it for ``finetune_trainer.py`` with 2 GPUs:
.. code-block:: bash
cd examples/seq2seq
python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py \
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
--output_dir output_dir --overwrite_output_dir \
--do_train --n_train 500 --num_train_epochs 1 \
--per_device_train_batch_size 1 --freeze_embeds \
--src_lang en_XX --tgt_lang ro_RO --task translation \
--fp16 --sharded_ddp
Notes:
- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with ``--fp16`` too, to make things even faster.
- One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able
to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
significantly shorter training time.
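Under the hood, and simplifying a lot, enabling ``--sharded_ddp`` amounts to sharding the optimizer state with
FairScale's ``OSS`` wrapper and wrapping the model in its ``ShardedDataParallel``. The following is only a rough,
illustrative sketch of that idea - it is not the actual :class:`~transformers.Trainer` code and assumes a generic
``model`` plus an already initialized ``torch.distributed`` process group:

.. code-block:: python

    import torch
    from fairscale.optim import OSS
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


    def build_sharded_ddp(model, lr=3e-5):
        # shard the optimizer state across the participating ranks (optimizer state sharding)
        optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=lr)
        # reduce each gradient only to the rank that owns the matching optimizer shard (gradient sharding)
        model = ShardedDDP(model, optimizer)
        return model, optimizer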
DeepSpeed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
`DeepSpeed <https://github.com/microsoft/DeepSpeed>`__ implements everything described in the `ZeRO paper
<https://arxiv.org/abs/1910.02054>`__, except ZeRO's stage 3, "Parameter Partitioning (Pos+g+p)". Currently it provides
full support for:
1. Optimizer State Partitioning (ZeRO stage 1)
2. Gradient Partitioning (ZeRO stage 2)
To deploy this feature:
1. Install the library via pypi:
.. code-block:: bash
pip install deepspeed
or find more details on `DeepSpeed's GitHub page <https://github.com/microsoft/deepspeed#installation>`__.
2. Adjust the :class:`~transformers.Trainer` command line arguments as follows:
1. replace ``python -m torch.distributed.launch`` with ``deepspeed``.
2. add a new argument ``--deepspeed ds_config.json``, where ``ds_config.json`` is the DeepSpeed configuration file
as documented `here <https://www.deepspeed.ai/docs/config-json/>`__. The file naming is up to you.
Therefore, if your original command line looked as follows:
.. code-block:: bash
python -m torch.distributed.launch --nproc_per_node=2 your_program.py <normal cl args>
Now it should be:
.. code-block:: bash
deepspeed --num_gpus=2 your_program.py <normal cl args> --deepspeed ds_config.json
Unlike ``torch.distributed.launch``, where you have to specify how many GPUs to use with ``--nproc_per_node``, with
the ``deepspeed`` launcher you don't have to use the corresponding ``--num_gpus`` if you want all of your GPUs used.
The full details on how to configure various nodes and GPUs can be found `here
<https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node>`__.
Here is an example of running ``finetune_trainer.py`` under DeepSpeed deploying all available GPUs:
.. code-block:: bash
cd examples/seq2seq
deepspeed ./finetune_trainer.py --deepspeed ds_config.json \
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
--output_dir output_dir --overwrite_output_dir \
--do_train --n_train 500 --num_train_epochs 1 \
--per_device_train_batch_size 1 --freeze_embeds \
--src_lang en_XX --tgt_lang ro_RO --task translation
Note that in the DeepSpeed documentation you are likely to see ``--deepspeed --deepspeed_config ds_config.json`` -
i.e. two DeepSpeed-related arguments, but for the sake of simplicity, and since there are already so many arguments
to deal with, we combined the two into a single argument.
Before you can deploy DeepSpeed, let's discuss its configuration.
**Configuration:**
For the complete guide to the DeepSpeed configuration options that can be used in its configuration file please refer
to the `following documentation <https://www.deepspeed.ai/docs/config-json/>`__.
While you always have to supply the DeepSpeed configuration file, you can configure the DeepSpeed integration in
several ways:
1. Supply most of the configuration inside the file, and just use a few required command line arguments. This is the
recommended way as it puts most of the configuration params in one place.
2. Supply just the ZeRO configuration params inside the file, and configure the rest using the normal
:class:`~transformers.Trainer` command line arguments.
3. Any variation of the first two ways.
To get an idea of what a DeepSpeed configuration file looks like, here is one that activates the ZeRO stage 2 features,
enables FP16, and uses the AdamW optimizer and the WarmupLR scheduler:
.. code-block:: json
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"cpu_offload": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [ 0.8, 0.999 ],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"zero_allow_untested_optimizer": true,
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
}
}
If you already have a command line that you have been using with :class:`transformers.Trainer` args, you can continue
using those and the :class:`~transformers.Trainer` will automatically convert them into the corresponding DeepSpeed
configuration at run time. For example, you could use the following configuration file:
.. code-block:: json
{
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"cpu_offload": true
}
}
and the following command line arguments:
.. code-block:: bash
--learning_rate 3e-5 --warmup_steps 500 --adam_beta1 0.8 --adam_beta2 0.999 --adam_epsilon 1e-8 \
--weight_decay 3e-7 --lr_scheduler_type constant_with_warmup --fp16 --fp16_backend amp
to achieve the same configuration as provided by the longer json file in the first example.
When you execute the program, DeepSpeed will log the configuration it received from the :class:`~transformers.Trainer`
to the console, so you can see exactly what final configuration was passed to it.
**Shared Configuration:**
Some configuration information is required by both the :class:`~transformers.Trainer` and DeepSpeed to function
correctly. Therefore, to prevent conflicting definitions, which could lead to hard-to-detect errors, we chose to
configure those via the :class:`~transformers.Trainer` command line arguments.
Consequently, the following DeepSpeed configuration params shouldn't be used with the :class:`~transformers.Trainer`:
* ``train_batch_size``
* ``train_micro_batch_size_per_gpu``
* ``gradient_accumulation_steps``
as these will be automatically derived from the run time environment and the following 2 command line arguments:
.. code-block:: bash
--per_device_train_batch_size 8 --gradient_accumulation_steps 2
which are always required to be supplied.
Of course, you will need to adjust the values in this example to your situation.
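To see how these two arguments translate into the effective DeepSpeed batch size, here is a tiny illustrative
computation (the helper name is ours, not part of any API); DeepSpeed derives ``train_batch_size = n_gpus x
train_micro_batch_size_per_gpu x gradient_accumulation_steps``:

.. code-block:: python

    def derived_train_batch_size(n_gpus, per_device_train_batch_size, gradient_accumulation_steps):
        # train_micro_batch_size_per_gpu is taken from --per_device_train_batch_size and
        # gradient_accumulation_steps from --gradient_accumulation_steps
        return n_gpus * per_device_train_batch_size * gradient_accumulation_steps


    # with the example arguments above on 2 GPUs the effective batch size is 32
    assert derived_train_batch_size(2, 8, 2) == 32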
**ZeRO:**
The ``zero_optimization`` section of the configuration file is the most important part (`docs
<https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training>`__), since that is where you define
which ZeRO stages you want to enable and how to configure them.
.. code-block:: json
{
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"contiguous_gradients": true,
"cpu_offload": true
}
}
Notes:
- enabling ``cpu_offload`` should reduce GPU RAM usage (it requires ``"stage": 2``)
- ``"overlap_comm": true`` trades off increased GPU RAM usage to lower all-reduce latency. ``overlap_comm`` uses 4.5x
the ``allgather_bucket_size`` and ``reduce_bucket_size`` values. So if they are set to 5e8, this requires a 9GB
footprint (``5e8 x 2Bytes x 2 x 4.5``). Therefore, if you have a GPU with 8GB or less RAM, to avoid getting
OOM-errors you will need to reduce those parameters to about ``2e8``, which would require 3.6GB.
This section has to be configured exclusively via DeepSpeed configuration - the :class:`~transformers.Trainer` provides
no equivalent command line arguments.
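To make the memory math from the notes above concrete, here is a small illustrative calculation (the helper is ours,
only for double-checking the quoted numbers):

.. code-block:: python

    def overlap_comm_footprint_gb(allgather_bucket_size, reduce_bucket_size, bytes_per_elem=2):
        # fp16 elements take 2 bytes each; overlap_comm needs roughly 4.5x the two bucket sizes combined
        return (allgather_bucket_size + reduce_bucket_size) * bytes_per_elem * 4.5 / 1e9


    print(overlap_comm_footprint_gb(5e8, 5e8))  # 9.0 -> the 9GB footprint mentioned above
    print(overlap_comm_footprint_gb(2e8, 2e8))  # 3.6 -> a better fit for a GPU with 8GB of RAM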
**Optimizer:**
DeepSpeed's main optimizers are Adam, OneBitAdam, and Lamb. These have been thoroughly tested with ZeRO and are thus
recommended. DeepSpeed can, however, also import other optimizers from ``torch``. The full documentation is `here
<https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`__.
If you don't configure the ``optimizer`` entry in the configuration file, the :class:`~transformers.Trainer` will
automatically set it to ``AdamW`` and will use the supplied values or the defaults for the following command line
arguments: ``--learning_rate``, ``--adam_beta1``, ``--adam_beta2``, ``--adam_epsilon`` and ``--weight_decay``.
Here is an example of the pre-configured ``optimizer`` entry for AdamW:
.. code-block:: json
{
"zero_allow_untested_optimizer": true,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 0.001,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
}
}
Since AdamW isn't on the list of optimizers tested with DeepSpeed/ZeRO, we have to add the
``zero_allow_untested_optimizer`` flag.
If you want to use one of the officially supported optimizers, configure it explicitly in the configuration file, and
make sure to adjust the values, e.g. if you use Adam you will want ``weight_decay`` around ``0.01``.
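For reference, here is roughly how the :class:`~transformers.Trainer` builds the ``optimizer`` entry when the
configuration file doesn't provide one - a simplified sketch of the integration logic, where ``args`` stands for the
:class:`~transformers.TrainingArguments` instance and the helper name is ours:

.. code-block:: python

    def hf_to_ds_optimizer_config(args):
        # mirror --learning_rate, --adam_beta1/--adam_beta2, --adam_epsilon and --weight_decay
        # into DeepSpeed's AdamW entry; AdamW is untested by DeepSpeed, hence the extra flag
        return {
            "zero_allow_untested_optimizer": True,
            "optimizer": {
                "type": "AdamW",
                "params": {
                    "lr": args.learning_rate,
                    "betas": [args.adam_beta1, args.adam_beta2],
                    "eps": args.adam_epsilon,
                    "weight_decay": args.weight_decay,
                },
            },
        }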
**Scheduler:**
DeepSpeed supports LRRangeTest, OneCycle, WarmupLR and WarmupDecayLR LR schedulers. The full documentation is `here
<https://www.deepspeed.ai/docs/config-json/#scheduler-parameters>`__.
If you don't configure the ``scheduler`` entry in the configuration file, the :class:`~transformers.Trainer` will use
the value of ``--lr_scheduler_type`` to configure it. Currently the :class:`~transformers.Trainer` supports only 2 LR
schedulers that are also supported by DeepSpeed:
* ``WarmupLR`` via ``--lr_scheduler_type constant_with_warmup``
* ``WarmupDecayLR`` via ``--lr_scheduler_type linear``. This is also the default value for ``--lr_scheduler_type``,
therefore, if you don't configure the scheduler, this is the scheduler that will get configured by default.
In either case, the values of ``--learning_rate`` and ``--warmup_steps`` will be used for the configuration.
In other words, if you don't use the configuration file to set the ``scheduler`` entry, provide either:
.. code-block:: bash
--lr_scheduler_type constant_with_warmup --learning_rate 3e-5 --warmup_steps 500
or
.. code-block:: bash
--lr_scheduler_type linear --learning_rate 3e-5 --warmup_steps 500
with the desired values. If you don't pass these arguments, reasonable default values will be used instead.
In the case of WarmupDecayLR, ``total_num_steps`` gets set either via the ``--max_steps`` command line argument, or,
if it is not provided, derived automatically at run time from the environment, the size of the dataset and other
command line arguments.
Here is an example of the pre-configured ``scheduler`` entry for WarmupLR (``constant_with_warmup`` in the
:class:`~transformers.Trainer` API):
.. code-block:: json
{
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 0.001,
"warmup_num_steps": 1000
}
}
}
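Similarly, here is roughly how the two supported ``--lr_scheduler_type`` values get remapped into the DeepSpeed
``scheduler`` entry when the configuration file leaves it out - again a simplified sketch with an illustrative helper
name:

.. code-block:: python

    from transformers.trainer_utils import SchedulerType


    def hf_to_ds_scheduler_config(args, num_training_steps):
        if args.lr_scheduler_type == SchedulerType.LINEAR:
            # --lr_scheduler_type linear -> WarmupDecayLR, which needs the total step count
            return {
                "type": "WarmupDecayLR",
                "params": {
                    "last_batch_iteration": -1,
                    "total_num_steps": num_training_steps,
                    "warmup_min_lr": 0,
                    "warmup_max_lr": args.learning_rate,
                    "warmup_num_steps": args.warmup_steps,
                },
            }
        elif args.lr_scheduler_type == SchedulerType.CONSTANT_WITH_WARMUP:
            # --lr_scheduler_type constant_with_warmup -> WarmupLR
            return {
                "type": "WarmupLR",
                "params": {
                    "warmup_min_lr": 0,
                    "warmup_max_lr": args.learning_rate,
                    "warmup_num_steps": args.warmup_steps,
                },
            }
        raise ValueError(f"{args.lr_scheduler_type} scheduler type is not supported by DeepSpeed")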
**Automatic Mixed Precision:**
You can work with FP16 in one of the following ways:
1. PyTorch native AMP, as documented `here <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`__.
2. NVIDIA's apex, as documented `here
<https://www.deepspeed.ai/docs/config-json/#automatic-mixed-precision-amp-training-options>`__.
If you want to use an equivalent of the PyTorch native AMP, you can either configure the ``fp16`` entry in the
configuration file, or use the following command line arguments: ``--fp16 --fp16_backend amp``.
Here is an example of the ``fp16`` configuration:
.. code-block:: json
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
}
}
If you want to use NVIDIA's apex instead, you can either configure the ``amp`` entry in the configuration file, or
use the following command line arguments: ``--fp16 --fp16_backend apex --fp16_opt_level O1``.
Here is an example of the ``amp`` configuration:
.. code-block:: json
{
"amp": {
"enabled": true,
"opt_level": "O1"
}
}
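To summarize the two backends, here is a rough sketch (the helper name is ours, not the exact integration code) of how
``--fp16_backend`` selects between the ``fp16`` and ``amp`` entries when neither is present in the configuration file:

.. code-block:: python

    def hf_to_ds_fp16_config(fp16_backend, fp16_opt_level="O1"):
        # --fp16 --fp16_backend amp -> DeepSpeed's own fp16 mode with dynamic loss scaling
        if fp16_backend == "amp":
            return {"fp16": {"enabled": True}}
        # --fp16 --fp16_backend apex -> delegate mixed precision to apex
        # (note: apex cannot be combined with the ZeRO features)
        if fp16_backend == "apex":
            return {"amp": {"enabled": True, "opt_level": fp16_opt_level}}
        return {}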
**Gradient Clipping:**
If you don't configure the ``gradient_clipping`` entry in the configuration file, the :class:`~transformers.Trainer`
will use the value of the ``--max_grad_norm`` command line argument to set it.
Here is an example of the ``gradient_clipping`` configuration:
.. code-block:: json
{
"gradient_clipping": 1.0,
}
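As with the other entries, the fallback is essentially a one-liner in the integration - a sketch, with ``config`` being
the loaded DeepSpeed configuration dict and ``args`` the :class:`~transformers.TrainingArguments`:

.. code-block:: python

    # only set if the ds config file doesn't already define gradient_clipping
    if "gradient_clipping" not in config:
        config["gradient_clipping"] = args.max_grad_norm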
**Notes:**
* DeepSpeed works with the PyTorch :class:`~transformers.Trainer` but not TF :class:`~transformers.TFTrainer`.
* While DeepSpeed has a pip installable PyPI package, it is highly recommended that it gets installed from `source
<https://github.com/microsoft/deepspeed#installation>`__ to best match your hardware, and also if you need to enable
certain features, like 1-bit Adam, which aren't available in the PyPI distribution.
* You don't have to use the :class:`~transformers.Trainer` to use DeepSpeed with HuggingFace ``transformers`` - you can
use any model with your own trainer, and you will have to adapt the latter according to `the DeepSpeed integration
instructions <https://www.deepspeed.ai/getting-started/#writing-deepspeed-models>`__ (a minimal sketch follows this
list).
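Here is a very rough sketch of what such a custom setup could look like. It assumes an already instantiated ``model``
and ``data_loader``, a single-GPU run (in a real multi-GPU run the local rank comes from the launcher), and the same
kind of ``ds_config.json`` discussed throughout this section - please refer to the DeepSpeed docs linked above for the
authoritative version:

.. code-block:: python

    import json
    from types import SimpleNamespace

    import deepspeed

    with open("ds_config.json", encoding="utf-8") as f:
        ds_config = json.load(f)

    # DeepSpeed expects an args-like object; here only local_rank is passed through it
    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=SimpleNamespace(local_rank=0),
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config_params=ds_config,
    )

    for batch in data_loader:
        loss = model_engine(**batch).loss  # assumes a model whose forward returns an object with .loss
        model_engine.backward(loss)  # DeepSpeed handles loss scaling and gradient accumulation
        model_engine.step()  # optimizer and lr scheduler step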
**Main DeepSpeed Resources:**
- `github <https://github.com/microsoft/deepspeed>`__
- `Usage docs <https://www.deepspeed.ai/getting-started/>`__
- `API docs <https://deepspeed.readthedocs.io/en/latest/index.html>`__
Finally, please remember that the HuggingFace :class:`~transformers.Trainer` only integrates DeepSpeed, therefore if
you have any problems or questions with regard to DeepSpeed usage, please file an issue with the `DeepSpeed GitHub
<https://github.com/microsoft/DeepSpeed/issues>`__.
......@@ -278,45 +278,6 @@ pass it to the trainer.
Finally, you can view the results, including any calculated metrics, by launching tensorboard in your specified
``logging_dir`` directory.
Trainer Integrations
-----------------------------------------------------------------------------------------------------------------------
The trainer is being extended to support experimental libraries that may dramatically improve your training time and
fit bigger models.
The main part that is being integrated at the moment is based on the paper `ZeRO: Memory Optimizations Toward Training
Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He
<https://arxiv.org/abs/1910.02054>`__.
You can already deploy the following features from this paper:
* Optimizer State Sharding
* Gradient Sharding
using the `--sharded_ddp` trainer argument. This is implemented via `fairscale
<https://github.com/facebookresearch/fairscale/>`__, so you will have to install this library.
This feature requires distributed training (so multiple GPUs) and is not implemented for TPUs.
For example here is how you could use it for `finetune_trainer.py`:
.. code-block:: bash
cd examples/seq2seq
python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py \
--model_name_or_path sshleifer/distill-mbart-en-ro-12-4 --data_dir wmt_en_ro \
--output_dir output_dir --overwrite_output_dir \
--do_train --n_train 500 --num_train_epochs 1 \
--per_device_train_batch_size 1 --freeze_embeds \
--src_lang en_XX --tgt_lang ro_RO --task translation \
--fp16 --sharded_ddp
Note that it works with `--fp16` too, to make things even faster.
One of the main benefits of enabling `--sharded_ddp` is that it uses a lot less GPU memory, so you should be able to
use significantly larger batch sizes using the same hardware (e.g. 3x or bigger).
Eventually more parts will be supported via integrating `DeepSpeed <https://github.com/microsoft/DeepSpeed>`__.
.. _additional-resources:
......
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true,
"cpu_offload": true
},
"zero_allow_untested_optimizer": true,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [
0.8,
0.999
],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
......@@ -18,7 +18,7 @@ import unittest
from unittest.mock import patch
from transformers.file_utils import is_apex_available
from transformers.integrations import is_fairscale_available
from transformers.integrations import is_deepspeed_available, is_fairscale_available
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
......@@ -49,6 +49,17 @@ def require_fairscale(test_case):
return test_case
# a candidate for testing_utils
def require_deepspeed(test_case):
"""
Decorator marking a test that requires deepspeed
"""
if not is_deepspeed_available():
return unittest.skip("test requires deepspeed")(test_case)
else:
return test_case
# a candidate for testing_utils
def require_apex(test_case):
"""
......@@ -61,8 +72,8 @@ def require_apex(test_case):
class TestFinetuneTrainer(TestCasePlus):
def finetune_trainer_quick(self, distributed=None, extra_args_str=None):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
def finetune_trainer_quick(self, distributed=None, deepspeed=False, extra_args_str=None):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, deepspeed, extra_args_str)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
......@@ -96,6 +107,11 @@ class TestFinetuneTrainer(TestCasePlus):
def test_finetune_trainer_apex(self):
self.finetune_trainer_quick(extra_args_str="--fp16 --fp16_backend=apex")
@require_torch_multi_gpu
@require_deepspeed
def test_finetune_trainer_deepspeed(self):
self.finetune_trainer_quick(deepspeed=True)
@slow
def test_finetune_trainer_slow(self):
# There is a missing call to __init__process_group somewhere
......@@ -125,6 +141,7 @@ class TestFinetuneTrainer(TestCasePlus):
model_name: str,
num_train_epochs: int,
distributed: bool = False,
deepspeed: bool = False,
extra_args_str: str = None,
):
data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
......@@ -164,7 +181,15 @@ class TestFinetuneTrainer(TestCasePlus):
if extra_args_str is not None:
args.extend(extra_args_str.split())
if distributed:
if deepspeed:
ds_args = f"--deepspeed {self.test_file_dir_str}/ds_config.json".split()
distributed_args = f"""
{self.test_file_dir}/finetune_trainer.py
""".split()
cmd = ["deepspeed"] + distributed_args + args + ds_args
execute_subprocess_async(cmd, env=self.get_env())
elif distributed:
n_gpu = get_gpu_count()
distributed_args = f"""
-m torch.distributed.launch
......@@ -173,6 +198,7 @@ class TestFinetuneTrainer(TestCasePlus):
""".split()
cmd = [sys.executable] + distributed_args + args
execute_subprocess_async(cmd, env=self.get_env())
else:
testargs = ["finetune_trainer.py"] + args
with patch.object(sys, "argv", testargs):
......
......@@ -15,13 +15,17 @@
Integrations with other Python libraries.
"""
import importlib.util
import io
import json
import math
import numbers
import os
import re
import tempfile
from pathlib import Path
from types import SimpleNamespace
from .trainer_utils import SchedulerType
from .utils import logging
......@@ -43,7 +47,6 @@ if _has_comet:
except (ImportError, ValueError):
_has_comet = False
from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available # noqa: E402
from .trainer_callback import TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy # noqa: E402
......@@ -94,6 +97,10 @@ def is_fairscale_available():
return importlib.util.find_spec("fairscale") is not None
def is_deepspeed_available():
return importlib.util.find_spec("deepspeed") is not None
def hp_params(trial):
if is_optuna_available():
import optuna
......@@ -230,6 +237,157 @@ def rewrite_logs(d):
return new_d
def init_deepspeed(trainer, num_training_steps):
"""
Init DeepSpeed, after converting the relevant Trainer args into the DeepSpeed configuration
Args:
trainer: Trainer object
num_training_steps: per single gpu
Returns: model, optimizer, lr_scheduler
"""
import deepspeed
args = trainer.args
ds_config_file = args.deepspeed
model = trainer.model
with io.open(ds_config_file, "r", encoding="utf-8") as f:
config = json.load(f)
# The following code translates relevant trainer's cl args into the DS config
# First to ensure that there is no mismatch between cl args values and presets in the config
# file, ask to not set in ds config file:
# - "train_batch_size",
# - "train_micro_batch_size_per_gpu",
# - "gradient_accumulation_steps"
bs_keys = ["train_batch_size", "train_micro_batch_size_per_gpu"]
if len([x for x in bs_keys if x in config.keys()]):
raise ValueError(
f"Do not include {bs_keys} entries in the ds config file, as they will be set via --per_device_train_batch_size or its default"
)
if "gradient_accumulation_steps" in config.keys():
raise ValueError(
"Do not include gradient_accumulation_steps entries in the ds config file, as they will be set via --gradient_accumulation_steps or its default"
)
# DeepSpeed does:
# train_batch_size = n_gpus * train_micro_batch_size_per_gpu * gradient_accumulation_steps
# therefore we just need to set:
config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
if "gradient_clipping" in config:
logger.info(
f"Keeping the `gradient_clipping` config from {ds_config_file} intact, ignoring any gradient clipping-specific cl args"
)
else: # override only if the ds config doesn't already have this section
config["gradient_clipping"] = args.max_grad_norm
if "optimizer" in config:
logger.info(
f"Keeping the `optimizer` config from {ds_config_file} intact, ignoring any optimizer-specific cl args"
)
else: # override only if the ds config doesn't already have this section
# ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
# But trainer uses AdamW by default.
# Using a non-tested optimizer, like AdamW, requires voiding the warranty with: `zero_allow_untested_optimizer`
optimizer_configs = {
"AdamW": {
"lr": args.learning_rate,
"betas": [args.adam_beta1, args.adam_beta2],
"eps": args.adam_epsilon,
"weight_decay": args.weight_decay,
}
}
optimizer = "AdamW"
config["zero_allow_untested_optimizer"] = True
config["optimizer"] = {
"type": optimizer,
"params": optimizer_configs[optimizer],
}
# DS schedulers (deepspeed/runtime/lr_schedules.py):
#
# DS name | --lr_scheduler_type | HF func | Notes
# -------------| ---------------------|-----------------------------------|--------------------
# LRRangeTest | na | na | LRRT
# OneCycle | na | na | 1CLR
# WarmupLR | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
# WarmupDecayLR| linear | get_linear_schedule_with_warmup |
if "scheduler" in config:
logger.info(
f"Keeping the `scheduler` config from {ds_config_file} intact, ignoring any scheduler-specific cl args"
)
else: # override only if the ds config doesn't already have this section
if args.lr_scheduler_type == SchedulerType.LINEAR:
scheduler = "WarmupDecayLR"
params = {
"last_batch_iteration": -1,
"total_num_steps": num_training_steps,
"warmup_min_lr": 0,
"warmup_max_lr": args.learning_rate,
"warmup_num_steps": args.warmup_steps,
}
elif args.lr_scheduler_type == SchedulerType.CONSTANT_WITH_WARMUP:
scheduler = "WarmupLR"
params = {
"warmup_min_lr": 0,
"warmup_max_lr": args.learning_rate,
"warmup_num_steps": args.warmup_steps,
}
else:
raise ValueError(f"{args.lr_scheduler_type} scheduler type is not supported by DeepSpeed")
config["scheduler"] = {
"type": scheduler,
"params": params,
}
# fp16
if trainer.fp16_backend is not None:
# Deepspeed has 2 possible fp16 config entries:
# - `fp16`: for the native amp - it has a bunch of optional params but we won't set any here unless the user did the work
# - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so probably best to be avoided.
if trainer.fp16_backend == "apex":
if "amp" in config:
logger.info(
f"Keeping the `amp` config from {ds_config_file} intact, ignoring any amp-specific cl args"
)
else:
config["amp"] = {
"enabled": True,
"opt_level": args.fp16_opt_level,
}
elif trainer.fp16_backend == "amp":
if "fp16" in config:
logger.info(
f"Keeping the `fp16` config from {ds_config_file} intact, ignoring any fp16-specific cl args"
)
else:
config["fp16"] = {
"enabled": True,
}
# for clarity extract the specific cl args that are being passed to deepspeed
ds_args = dict(local_rank=args.local_rank)
# init that takes part of the config via `args`, and the bulk of it via `config_params`
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize(
args=SimpleNamespace(**ds_args), # expects an obj
model=model,
model_parameters=model_parameters,
config_params=config,
)
return model, optimizer, lr_scheduler
class TensorBoardCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
......
......@@ -42,6 +42,7 @@ from .integrations import ( # isort: split
is_wandb_available,
run_hp_search_optuna,
run_hp_search_ray,
init_deepspeed,
)
import numpy as np
......@@ -252,6 +253,7 @@ class Trainer:
# Seed must be set before instantiating the model when using model
set_seed(self.args.seed)
self.hp_name = None
self.deepspeed = None
if model is None:
if model_init is not None:
......@@ -338,20 +340,25 @@ class Trainer:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif args.deepspeed:
raise ValueError("can't use --sharded_ddp together with --deepspeed.")
else:
self.sharded_dpp = True
# Mixed precision setup
self.use_apex = False
self.use_amp = False
self.fp16_backend = None
if args.fp16:
if args.fp16_backend == "auto":
backend = "amp" if _is_native_amp_available else "apex"
self.fp16_backend = "amp" if _is_native_amp_available else "apex"
else:
backend = args.fp16_backend
logger.info(f"Using {backend} fp16 backend")
self.fp16_backend = args.fp16_backend
logger.info(f"Using {self.fp16_backend} fp16 backend")
if backend == "amp":
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
if self.fp16_backend == "amp":
self.use_amp = True
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
else:
......@@ -714,7 +721,16 @@ class Trainer:
num_train_epochs = 1
num_update_steps_per_epoch = max_steps
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
self.model = model.module
self.model_wrapped = model # will get further wrapped in DDP
self.deepspeed = model # DeepSpeedEngine object
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
else:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
self.state = TrainerState()
self.state.is_hyper_param_search = trial is not None
......@@ -878,7 +894,9 @@ class Trainer:
and (step + 1) == steps_in_epoch
):
# Gradient clipping
if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0:
if self.args.max_grad_norm is not None and self.args.max_grad_norm > 0 and not self.deepspeed:
# deepspeed does its own clipping
if self.use_amp:
# AMP: gradients need unscaling
self.scaler.unscale_(self.optimizer)
......@@ -945,6 +963,11 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
if self.deepspeed:
self.deepspeed.load_checkpoint(
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
)
metrics = speed_metrics("train", start_time, self.state.max_steps)
if self._total_flos is not None:
self.store_flos()
......@@ -1006,18 +1029,23 @@ class Trainer:
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
self.store_flos()
self.save_model(output_dir)
if self.deepspeed:
self.deepspeed.save_checkpoint(output_dir)
# Save optimizer and scheduler
if self.sharded_dpp:
self.optimizer.consolidate_state_dict()
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero():
elif self.is_world_process_zero() and not self.deepspeed:
# deepspeed.save_checkpoint above saves model/optim/sched
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
......@@ -1049,10 +1077,11 @@ class Trainer:
def _load_optimizer_and_scheduler(self, model_path):
"""If optimizer and scheduler states exist, load them."""
if (
model_path is not None
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
if model_path is None:
return
if os.path.isfile(os.path.join(model_path, "optimizer.pt")) and os.path.isfile(
os.path.join(model_path, "scheduler.pt")
):
# Load in optimizer and scheduler states
if is_torch_tpu_available():
......@@ -1075,6 +1104,10 @@ class Trainer:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
if self.deepspeed:
# Not sure how to check if there is a saved deepspeed checkpoint, but since it just returns None if it fails to find a deepspeed checkpoint, this is in effect a check-and-load function
self.deepspeed.load_checkpoint(model_path, load_optimizer_states=True, load_lr_scheduler_states=True)
def hyperparameter_search(
self,
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
......@@ -1227,6 +1260,9 @@ class Trainer:
elif self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
elif self.deepspeed:
# calling on DS engine (model_wrapped == DDP(Deepspeed(PretrainedModule)))
self.model_wrapped.module.backward(loss)
else:
loss.backward()
......
......@@ -217,6 +217,9 @@ class TrainingArguments:
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.
deepspeed (:obj:`str`, `optional`):
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
label_smoothing_factor (:obj:`float`, `optional`, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to :obj:`label_smoothing_factor/num_labels` and :obj:`1 -
......@@ -394,6 +397,10 @@ class TrainingArguments:
default=False,
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
)
deepspeed: Optional[str] = field(
default=None,
metadata={"help": "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json)"},
)
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
......@@ -480,7 +487,21 @@ class TrainingArguments:
else:
# Here, we'll use torch.distributed.
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend="nccl")
#
# deepspeed performs its own DDP internally, and requires the program to be started with:
# deepspeed ./program.py
# rather than:
# python -m torch.distributed.launch --nproc_per_node=2 ./program.py
if self.deepspeed:
from .integrations import is_deepspeed_available
if not is_deepspeed_available():
raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
import deepspeed
deepspeed.init_distributed()
else:
torch.distributed.init_process_group(backend="nccl")
device = torch.device("cuda", self.local_rank)
n_gpu = 1
......