Unverified Commit 7ec596ec authored by Stas Bekman, committed by GitHub

[DeepSpeed] decouple `DeepSpeedConfigHF` from `Trainer` (#11966)



* decouple DeepSpeedConfigHF from Trainer

* add LoggingLevel ctx manager; add new test

* cleanup

* add docs

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* implemented suggested renames

* formatter workaround
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 1c3ab3e5
@@ -468,6 +468,7 @@ Flax), PyTorch, and/or TensorFlow.
     main_classes/processors
     main_classes/tokenizer
     main_classes/trainer
+    main_classes/deepspeed
     main_classes/feature_extractor
 
 .. toctree::
...
..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.
HfDeepSpeedConfig
-----------------------------------------------------------------------------------------------------------------------

The :class:`~transformers.integrations.HfDeepSpeedConfig` is used to integrate DeepSpeed into the 🤗 Transformers core
functionality, when :class:`~transformers.Trainer` is not used.

When using :class:`~transformers.Trainer` everything is automatically taken care of.

When not using :class:`~transformers.Trainer`, to efficiently deploy DeepSpeed ZeRO stage 3, you must instantiate the
:class:`~transformers.integrations.HfDeepSpeedConfig` object before instantiating the model.

For example, for a pretrained model:

.. code-block:: python

    import deepspeed  # needed for deepspeed.initialize below

    from transformers.integrations import HfDeepSpeedConfig
    from transformers import AutoModel

    ds_config = { ... }  # deepspeed config object or path to the file
    # must run before instantiating the model
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive
    model = AutoModel.from_pretrained("gpt2")
    engine = deepspeed.initialize(model=model, config_params=ds_config, ...)

or for a non-pretrained model:

.. code-block:: python

    import deepspeed  # needed for deepspeed.initialize below

    from transformers.integrations import HfDeepSpeedConfig
    from transformers import AutoModel, AutoConfig

    ds_config = { ... }  # deepspeed config object or path to the file
    # must run before instantiating the model
    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive
    config = AutoConfig.from_pretrained("gpt2")
    model = AutoModel.from_config(config)
    engine = deepspeed.initialize(model=model, config_params=ds_config, ...)
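
Once the ``HfDeepSpeedConfig`` object exists, the live configuration can be queried from anywhere in the process via
the module-level helpers in ``transformers.integrations``. A minimal sketch, assuming DeepSpeed itself is installed
(the constructor runs a dependency check) and using a tiny config dict like the one in the new test:

.. code-block:: python

    from transformers.integrations import (
        HfDeepSpeedConfig,
        deepspeed_config,
        is_deepspeed_zero3_enabled,
    )

    ds_config = {"train_batch_size": 1, "zero_optimization": {"stage": 3}}

    dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

    assert dschf.is_zero3()
    assert is_deepspeed_zero3_enabled()  # answered via the module-level weakref
    assert deepspeed_config() is not None  # returns the stored config dict
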
HfDeepSpeedConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.integrations.HfDeepSpeedConfig
:members:
@@ -286,28 +286,57 @@ def _set_if_auto(config, key, val):
         config[key] = val
 
 
-class DeepSpeedConfigHF:
+class HfDeepSpeedConfig:
     """
-    This object contains Deepspeed configuration and can be quickly queried for things like zero stage.
+    This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
 
-    We store a ``weakref`` of this object in the module's global to be able to access the config from areas where the
-    Trainer is not available (e.g. `from_pretrained` and `_get_resized_embeddings`).
+    A ``weakref`` of this object is stored in the module's globals to be able to access the config from areas where
+    things like the Trainer object is not available (e.g. ``from_pretrained`` and ``_get_resized_embeddings``).
+
+    Therefore it's important that this object remains alive while the program is still running.
+
+    :class:`~transformers.Trainer` uses the ``HfTrainerDeepSpeedConfig`` subclass instead. That subclass has logic to
+    sync the configuration with values of :class:`~transformers.TrainingArguments` by replacing special placeholder
+    values: ``"auto"``. Without this special logic the DeepSpeed configuration is not modified in any way.
+
+    Args:
+        config_file_or_dict (:obj:`Union[str, Dict]`) - path to DeepSpeed config file or dict.
 
-    The ``DeepSpeedConfigHF`` object is meant to be created during ``TrainingArguments`` object creation and has the
-    same lifespan as the latter.
     """
 
-    def __init__(self, args):
-        self.config = None
-        self.stage = 0
-        self.offload = False
+    def __init__(self, config_file_or_dict):
+        # set global weakref object
+        set_hf_deepspeed_config(self)
 
         dep_version_check("deepspeed")
 
-        self.config_process(args)
-
-        # set global weakref object
-        deepspeed_config_hf_set(self)
+        if isinstance(config_file_or_dict, dict):
+            # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
+            # modified it, it will not be accepted here again, since `auto` values would have been overriden
+            config = deepcopy(config_file_or_dict)
+        elif isinstance(config_file_or_dict, str):
+            with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
+                config = json.load(f)
+        else:
+            raise ValueError("expecting either a path to a DeepSpeed config file or a pre-populated dict")
+        self.config = config
+
+        # zero stage - this is done as early as possible, before model is created, to allow
+        # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
+        # during ``zero.Init()`` which needs whether fp16 is enabled, dtype, etc.
+        config_zero = config.get("zero_optimization", {})
+        self.stage = config_zero.get("stage", 0)
+
+        # offload
+        self.offload = False
+        config_zero = config.get("zero_optimization", {})
+        if self.is_zero2():
+            self.offload = _is_true(config_zero, "cpu_offload")
+        elif self.is_zero3():
+            offload_devices = ["cpu", "nvme"]
+            if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
+                self.offload = True
+            if config_zero.get("offload_param", {}).get("device") in offload_devices:
+                self.offload = True
 
     def is_zero2(self):
         return self.stage == 2
@@ -318,28 +347,23 @@ class DeepSpeedConfigHF:
     def is_offload(self):
         return self.offload
 
-    def config_process(self, args):
-        """
-        1. load json if the ``args.deepspeed`` is a path
-        2. replace any ``auto`` values in the config with the correct or recommended value
-
-        This is done as early as possible, before model is created, to allow ``is_deepspeed_zero3_enabled`` query and
-        getting to the early deepspeed config object during ``zero.Init()`` which needs whether fp16 is enabled, dtype,
-        etc.
-        """
-        config_file_or_dict = args.deepspeed
-        if isinstance(config_file_or_dict, dict):
-            # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
-            # modified it, it will not be accepted here again, since `auto` values would have been overriden
-            config = deepcopy(config_file_or_dict)
-        elif isinstance(config_file_or_dict, str):
-            with io.open(config_file_or_dict, "r", encoding="utf-8") as f:
-                config = json.load(f)
-        else:
-            raise ValueError("expecting either a path to a config file or a pre-populated dict")
-
-        self.config = config
+
+class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
+    """
+    The ``HfTrainerDeepSpeedConfig`` object is meant to be created during ``TrainingArguments`` object creation and has
+    the same lifespan as the latter.
+    """
+
+    def __init__(self, config_file_or_dict):
+        super().__init__(config_file_or_dict)
+
+    def trainer_config_process(self, args):
+        """
+        Adjust the config with ``TrainingArguments`` values. This stage is run during ``TrainingArguments`` object
+        creation.
+        """
+        config = self.config
 
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
@@ -349,10 +373,6 @@ class DeepSpeedConfigHF:
         _set_if_auto(config, "train_batch_size", train_batch_size)
         _set_if_auto(config, "gradient_clipping", args.max_grad_norm)
 
-        # zero
-        config_zero = config.get("zero_optimization", {})
-        self.stage = config_zero.get("stage", 0)
-
         config_optim = config.get("optimizer", {})
         if config_optim != {}:
             config_optim_params = config_optim.get("params")
@@ -367,7 +387,7 @@ class DeepSpeedConfigHF:
             _set_if_auto(config_sched_params, "warmup_min_lr", 0)
             _set_if_auto(config_sched_params, "warmup_max_lr", args.learning_rate)
             _set_if_auto(config_sched_params, "warmup_num_steps", args.warmup_steps)
-            # total_num_steps - will get set in deepspeed_init
+            # total_num_steps - will get set in trainer_config_finalize
 
         # fp16
         if args.fp16:
@@ -381,27 +401,16 @@ class DeepSpeedConfigHF:
         _set_if_auto(config_fp16, "enabled", fp16_backend == "amp")
 
         # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
-        # ZeRO features, so probably best to be avoided.
+        # ZeRO features
         config_amp = config.get("amp")
         _set_if_auto(config_amp, "enabled", fp16_backend == "apex")
         _set_if_auto(config_amp, "opt_level", args.fp16_opt_level)
 
-        config_zero = config.get("zero_optimization", {})
-        if self.is_zero2():
-            self.offload = _is_true(config_zero, "cpu_offload")
-        elif self.is_zero3():
-            offload_devices = ["cpu", "nvme"]
-            if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
-                self.offload = True
-            if config_zero.get("offload_param", {}).get("device") in offload_devices:
-                self.offload = True
-
-    def config_finalize(self, args, model, num_training_steps):
+    def trainer_config_finalize(self, args, model, num_training_steps):
         """
         This stage is run after we have the model and know num_training_steps.
 
         Now we we can complete the configuration process.
         """
         config = self.config
@@ -421,27 +430,27 @@ class DeepSpeedConfigHF:
 
 # keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
-_deepspeed_config_hf_weak_ref = None
+_hf_deepspeed_config_weak_ref = None
 
 
-def deepspeed_config_hf_set(deepspeed_config_hf_obj):
+def set_hf_deepspeed_config(hf_deepspeed_config_obj):
     # this is a special weakref global object to allow us to get to Deepspeed config from APIs
     # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
-    global _deepspeed_config_hf_weak_ref
+    global _hf_deepspeed_config_weak_ref
 
-    # will go away automatically when DeepSpeedConfigHF is destroyed (when TrainingArguments is destroyed)
-    _deepspeed_config_hf_weak_ref = weakref.ref(deepspeed_config_hf_obj)
+    # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed)
+    _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
 
 
 def is_deepspeed_zero3_enabled():
-    if _deepspeed_config_hf_weak_ref is not None and _deepspeed_config_hf_weak_ref() is not None:
-        return _deepspeed_config_hf_weak_ref().is_zero3()
+    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
+        return _hf_deepspeed_config_weak_ref().is_zero3()
     else:
         return False
 
 
 def deepspeed_config():
-    if _deepspeed_config_hf_weak_ref is not None and _deepspeed_config_hf_weak_ref() is not None:
-        return _deepspeed_config_hf_weak_ref().config
+    if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
+        return _hf_deepspeed_config_weak_ref().config
     else:
         return None
@@ -464,11 +473,11 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
 
     model = trainer.model
 
-    deepspeed_config_hf = trainer.args.deepspeed_config_hf
-    deepspeed_config_hf.config_finalize(trainer.args, model, num_training_steps)
+    hf_deepspeed_config = trainer.args.hf_deepspeed_config
+    hf_deepspeed_config.trainer_config_finalize(trainer.args, model, num_training_steps)
 
     # resume config update - some bits like `model` and `num_training_steps` only become available during train
-    config = deepspeed_config_hf.config
+    config = hf_deepspeed_config.config
 
     # Optimizer + Scheduler
     # Currently supported combos:
@@ -485,7 +494,7 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
 
     optimizer = None
     if "optimizer" not in config:
-        if deepspeed_config_hf.is_offload():
+        if hf_deepspeed_config.is_offload():
             raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")
 
         # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
...
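
The module-level weakref in the hunk above is what lets ``is_deepspeed_zero3_enabled()`` and ``deepspeed_config()``
answer correctly from places like ``from_pretrained`` that have no handle on the ``Trainer``. A minimal standalone
sketch of that pattern (illustrative names only, not the transformers API):

.. code-block:: python

    import weakref

    # module-level weak reference: the module itself never keeps the config object alive
    _live_config_ref = None


    class Config:
        def __init__(self, stage):
            global _live_config_ref
            self.stage = stage
            _live_config_ref = weakref.ref(self)  # dies together with the object


    def is_stage3_enabled():
        # the ref may be unset, or the object may already have been garbage-collected
        if _live_config_ref is not None and _live_config_ref() is not None:
            return _live_config_ref().stage == 3
        return False


    cfg = Config(stage=3)
    print(is_stage3_enabled())  # True while `cfg` is alive
    del cfg
    print(is_stage3_enabled())  # False (in CPython the weakref clears as soon as `cfg` is collected)
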
@@ -26,6 +26,8 @@ from io import StringIO
 from pathlib import Path
 from typing import Iterator, Union
 
+from transformers import logging as transformers_logging
+
 from .file_utils import (
     is_datasets_available,
     is_faiss_available,
@@ -648,6 +650,26 @@ class CaptureLogger:
         return f"captured: {self.out}\n"
 
 
+@contextlib.contextmanager
+def LoggingLevel(level):
+    """
+    This is a context manager to temporarily change transformers modules logging level to the desired value and have it
+    restored to the original setting at the end of the scope.
+
+    For example ::
+
+        with LoggingLevel(logging.INFO):
+            AutoModel.from_pretrained("gpt2")  # calls logger.info() several times
+
+    """
+    orig_level = transformers_logging.get_verbosity()
+    try:
+        transformers_logging.set_verbosity(level)
+        yield
+    finally:
+        transformers_logging.set_verbosity(orig_level)
+
+
 @contextlib.contextmanager
 # adapted from https://stackoverflow.com/a/64789046/9201239
 def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
...
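
The new ``LoggingLevel`` helper restores the previous verbosity even if the wrapped code raises, thanks to the
``try``/``finally``. A small usage sketch, assuming the helper is importable from ``transformers.testing_utils`` as
added above:

.. code-block:: python

    from transformers import logging
    from transformers.testing_utils import LoggingLevel

    logging.set_verbosity_error()

    try:
        with LoggingLevel(logging.INFO):
            assert logging.get_verbosity() == logging.INFO  # temporarily raised
            raise RuntimeError("boom")
    except RuntimeError:
        pass

    # the original verbosity survives the exception
    assert logging.get_verbosity() == logging.ERROR
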
@@ -863,9 +863,9 @@ class Trainer:
             logger.info("Trial:", trial.params)
         if self.args.deepspeed:
             # Rebuild the deepspeed config to reflect the updated training parameters
-            from transformers.integrations import DeepSpeedConfigHF
+            from transformers.integrations import HfDeepSpeedConfig
 
-            self.args.deepspeed_config_hf = DeepSpeedConfigHF(self.args)
+            self.args.hf_deepspeed_config = HfDeepSpeedConfig(self.args)
 
     def _report_to_hp_search(
         self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
...
@@ -671,10 +671,12 @@ class TrainingArguments:
         if self.deepspeed:
             # - must be run very last in arg parsing, since it will use a lot of these settings.
             # - must be run before the model is created.
-            from transformers.integrations import DeepSpeedConfigHF
+            from transformers.integrations import HfTrainerDeepSpeedConfig
 
-            # will be used later by the Trainer (leave self.deepspeed unmodified in case a user relies on it not to be modified)
-            self.deepspeed_config_hf = DeepSpeedConfigHF(self)
+            # will be used later by the Trainer
+            # note: leave self.deepspeed unmodified in case a user relies on it not to be modified)
+            self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
+            self.hf_deepspeed_config.trainer_config_process(self)
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
...
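
This hunk wires up stage one of the two-stage flow: ``trainer_config_process`` fills the ``"auto"`` placeholders that
depend only on ``TrainingArguments``, while ``trainer_config_finalize`` (called later from ``deepspeed_init``) fills
the ones that need the model and ``num_training_steps``. A simplified sketch of the placeholder mechanism, using a
stand-in helper rather than the actual transformers code:

.. code-block:: python

    def _set_if_auto(section, key, val):
        # fill a value only where the user explicitly left the "auto" placeholder
        if section is not None and section.get(key) == "auto":
            section[key] = val


    ds_config = {
        "gradient_clipping": "auto",
        "scheduler": {"params": {"total_num_steps": "auto", "warmup_num_steps": "auto"}},
    }

    # stage 1: at TrainingArguments creation (trainer_config_process)
    _set_if_auto(ds_config, "gradient_clipping", 1.0)  # e.g. args.max_grad_norm
    _set_if_auto(ds_config["scheduler"]["params"], "warmup_num_steps", 50)  # e.g. args.warmup_steps

    # stage 2: once the model and num_training_steps are known (trainer_config_finalize)
    _set_if_auto(ds_config["scheduler"]["params"], "total_num_steps", 1000)  # e.g. num_training_steps

    print(ds_config)
    # {'gradient_clipping': 1.0, 'scheduler': {'params': {'total_num_steps': 1000, 'warmup_num_steps': 50}}}
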
@@ -20,13 +20,14 @@ import unittest
 from copy import deepcopy
 
 from parameterized import parameterized
-from transformers import TrainingArguments, is_torch_available
+from transformers import AutoModel, TrainingArguments, is_torch_available, logging
 from transformers.file_utils import WEIGHTS_NAME
-from transformers.integrations import is_deepspeed_available
+from transformers.integrations import HfDeepSpeedConfig, is_deepspeed_available
 from transformers.testing_utils import (
     CaptureLogger,
     CaptureStderr,
     ExtendSysPath,
+    LoggingLevel,
     TestCasePlus,
     execute_subprocess_async,
     get_gpu_count,
@@ -77,6 +78,56 @@ ZERO3 = "zero3"
 stages = [ZERO2, ZERO3]
 
 
+@require_deepspeed
+@require_torch_gpu
+class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
+    """
+    Testing non-Trainer DeepSpeed integration
+    """
+
+    def setUp(self):
+        super().setUp()
+
+        self.dist_env_1_gpu = dict(
+            MASTER_ADDR="localhost", MASTER_PORT="10999", RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
+        )
+
+    def test_init_zero3(self):
+        # test that zero.Init() works correctly under zero3
+        ds_config = {
+            "train_batch_size": 1,
+            "zero_optimization": {
+                "stage": 3,
+            },
+        }
+
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertTrue(dschf.is_zero3())
+        self.assertTrue(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    AutoModel.from_pretrained(T5_TINY)
+                self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+        # now remove zero optimization
+        del ds_config["zero_optimization"]
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertFalse(dschf.is_zero3())
+        self.assertFalse(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    AutoModel.from_pretrained(T5_TINY)
+                self.assertNotIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+
 @require_deepspeed
 @require_torch_gpu
 class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
@@ -194,9 +245,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
         ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
         trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero3_dict)
-        with CaptureLogger(deepspeed_logger) as cs:
+        with CaptureLogger(deepspeed_logger) as cl:
             trainer.train()
-        self.assertIn("DeepSpeed info", cs.out, "expected DeepSpeed logger output but got none")
+        self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
 
     # --- These tests need to run on both zero stages --- #
@@ -230,9 +281,9 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # to reset `deepspeed_logger.handlers[0].setStream(sys.stdout)` or directly capture from the deepspeed_logger.
         with mockenv_context(**self.dist_env_1_gpu):
             trainer = get_regression_trainer(local_rank=0, deepspeed=self.get_config_dict(stage))
-            with CaptureLogger(deepspeed_logger) as cs:
+            with CaptureLogger(deepspeed_logger) as cl:
                 trainer.train()
-            self.assertIn("DeepSpeed info", cs.out, "expected DeepSpeed logger output but got none")
+            self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
 
     @parameterized.expand(stages)
     def test_early_get_last_lr(self, stage):
...