Unverified Commit c6d66484 authored by Stas Bekman, committed by GitHub

[DeepSpeed] ZeRO Stage 3 (#10753)



* synced gpus

* fix

* fix

* need to use t5-small for quality tests

* notes

* complete merge

* fix a disappearing std stream problem

* start zero3 tests

* wip

* tune params

* sorting out the pre-trained model loading

* reworking generate loop wip

* wip

* style

* fix tests

* split the tests

* refactor tests

* wip

* parameterized

* fix

* work out the resume from non-ds checkpoint pass + test

* cleanup

* remove no longer needed code

* split getter/setter functions

* complete the docs

* suggestions

* gpus and their compute capabilities link

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* style

* remove invalid param

* automatically configure zero3 params that rely on hidden size

* make _get_resized_embeddings zero3-aware

* add test exercising resize_token_embeddings()

* add docstring
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent acc851e1
......@@ -3,7 +3,7 @@
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 32,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
......
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 3,
"cpu_offload": true,
"cpu_offload_params": true,
"cpu_offload_use_pin_memory" : true,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e14,
"reduce_bucket_size": 0,
"stage3_prefetch_bucket_size": 0,
"stage3_param_persistence_threshold": 0,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_fp16_weights_on_model_save": true
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": 3e-5,
"betas": [0.8, 0.999],
"eps": 1e-8,
"weight_decay": 3e-7
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 500
}
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
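Note: the three `0` values above (`reduce_bucket_size`, `stage3_prefetch_bucket_size`, `stage3_param_persistence_threshold`) are placeholders that `deepspeed_init()` below fills in automatically from the model's `hidden_size`. As a quick illustration of the stage detection this PR keys off (the filename here is an assumption):

import json

with open("ds_config_zero3.json") as f:
    config = json.load(f)
# this is the condition the is_deepspeed_zero3_enabled() auto-discovery checks for
assert config["zero_optimization"]["stage"] == 3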
......@@ -19,6 +19,7 @@ import io
import json
import numbers
import os
import sys
import tempfile
from copy import deepcopy
from pathlib import Path
......@@ -268,7 +269,77 @@ def rewrite_logs(d):
return new_d
def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
_is_deepspeed_zero3_enabled = None
def is_deepspeed_zero3_enabled():
"""
This function answers the question of whether DeepSpeed is going to be used and run using ZeRO Stage 3.
It includes an auto-discovery method; see the comments in the code for details.
Returns: ``True`` if it was either explicitly enabled via ``deepspeed_zero3_enable(True)`` or the auto-detector was
able to derive that the ``Trainer`` will be running via DeepSpeed ZeRO stage 3.
"""
global _is_deepspeed_zero3_enabled
if _is_deepspeed_zero3_enabled is None:
_is_deepspeed_zero3_enabled = False
# Try to auto-discover if we are about to use DeepSpeed with ZeRO3 enabled. This will only
# work for scripts using the CLI to pass --deepspeed ds_config.json. If cmd args aren't used,
# then to get the model efficiently loaded across multiple GPUs one has to explicitly call
# deepspeed_zero3_enable(True) **before** instantiating a model object
if "--deepspeed" in sys.argv:
idx = sys.argv.index("--deepspeed")
ds_config = sys.argv[idx + 1]
if not os.path.exists(ds_config):
raise ValueError("--deepspeed requires a valid path to a config file")
config = deepspeed_parse_config(ds_config)
if (
"zero_optimization" in config
and "stage" in config["zero_optimization"]
and config["zero_optimization"]["stage"] == 3
):
_is_deepspeed_zero3_enabled = True
return _is_deepspeed_zero3_enabled
def deepspeed_zero3_enable(enable=True):
"""
``is_deepspeed_zero3_enabled()`` tries to derive automatically if DeepSpeed ZeRO 3 is going to be used by looking
at ``sys.argv``, which may or may not contain information about where to find the DeepSpeed config, if any.
This function allows for explicit enabling/disabling of this global flag.
Args:
enable: if set to ``True`` will make ``is_deepspeed_zero3_enabled()`` return ``True``
"""
global _is_deepspeed_zero3_enabled
_is_deepspeed_zero3_enabled = enable
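# Usage sketch (illustration only; the model name is an arbitrary example): when a script is not
# launched with `--deepspeed ds_config.json`, the sys.argv auto-discovery above cannot fire, so
# the flag has to be enabled explicitly *before* the model is instantiated:
#
#   from transformers import AutoModelForSeq2SeqLM
#   from transformers.integrations import deepspeed_zero3_enable
#   deepspeed_zero3_enable(True)
#   model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # now loads via zero.Init()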
def deepspeed_parse_config(ds_config):
"""
If ``ds_config`` isn't already a dict, read it from the config file.
If it's already a dict, return a copy of it, so that we can freely modify it.
"""
require_version("deepspeed>0.3.13")
if isinstance(ds_config, dict):
# Don't modify the user's data should they want to reuse it (e.g. in tests), because once we
# modify it, it will not be accepted here again, since some config params must not be set by users
config = deepcopy(ds_config)
elif isinstance(ds_config, str):
with io.open(ds_config, "r", encoding="utf-8") as f:
config = json.load(f)
else:
raise ValueError("expecting either a path to a config file or a pre-populated dict")
return config
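# Both accepted forms, as a sketch (the path is an assumption for illustration):
#
#   config = deepspeed_parse_config({"zero_optimization": {"stage": 3}})  # dict -> deep copy
#   config = deepspeed_parse_config("ds_config_zero3.json")               # str  -> load JSON file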
def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
"""
Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer args.
......@@ -284,21 +355,10 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
"""
import deepspeed
require_version("deepspeed>0.3.12")
args = trainer.args
ds_config_file = args.deepspeed
model = trainer.model
if isinstance(args.deepspeed, dict):
# Don't modify the user's data should they want to reuse it (e.g. in tests), because once we
# modify it, it will not be accepted here again, since some config params must not be set by users
config = deepcopy(args.deepspeed)
elif isinstance(args.deepspeed, str):
with io.open(ds_config_file, "r", encoding="utf-8") as f:
config = json.load(f)
else:
raise ValueError("expecting either a path to a config file or a pre-populated dict")
config = deepspeed_parse_config(args.deepspeed)
# The following code translates the relevant trainer command line args into the DS config
......@@ -324,9 +384,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
if "gradient_clipping" in config:
logger.info(
f"Keeping the `gradient_clipping` config from {ds_config_file} intact, ignoring any gradient clipping-specific cl args"
)
logger.info("Keeping the `gradient_clipping` config intact, ignoring any gradient clipping-specific cl args")
else: # override only if the ds config doesn't already have this section
config["gradient_clipping"] = args.max_grad_norm
......@@ -336,6 +394,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
# 2. HF scheduler + HF optimizer: Yes
# 3. DS scheduler + HF optimizer: Yes
# 4. HF scheduler + DS optimizer: No
#
# Unless Offload is enabled in which case it's:
# 1. DS scheduler + DS optimizer: Yes
# 2. HF scheduler + HF optimizer: No
......@@ -344,7 +403,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
optimizer = None
if "optimizer" in config:
logger.info(f"Updating the `scheduler` config from {ds_config_file} with other command line arguments")
logger.info("Updating the `scheduler` config with other command line arguments")
# to avoid inconsistent values of lr and warmup steps, the command line args override the config
params = dict(
......@@ -384,7 +443,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
# WarmupDecayLR| linear | get_linear_schedule_with_warmup |
lr_scheduler = None
if "scheduler" in config:
logger.info(f"Updating the `scheduler` config from {ds_config_file} with other command line arguments")
logger.info("Updating the `scheduler` config with other command line arguments")
# the user won't easily know the correct num_training_steps should they use WarmupDecayLR,
# so let's set it to the correct value
if config["scheduler"]["type"] == "WarmupDecayLR":
......@@ -417,9 +476,7 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
# - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so probably best to be avoided.
if trainer.fp16_backend == "apex":
if "amp" in config:
logger.info(
f"Keeping the `amp` config from {ds_config_file} intact, ignoring any amp-specific cl args"
)
logger.info("Keeping the `amp` config intact, ignoring any amp-specific cl args")
else:
config["amp"] = {
"enabled": True,
......@@ -427,19 +484,33 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
}
elif trainer.fp16_backend == "amp":
if "fp16" in config:
logger.info(
f"Keeping the `fp16` config from {ds_config_file} intact, ignoring any fp16-specific cl args"
)
logger.info("Keeping the `fp16` config intact, ignoring any fp16-specific cl args")
else:
config["fp16"] = {
"enabled": True,
}
# zero
if "zero_optimization" in config:
zero = config["zero_optimization"]
# now we know for sure if zero3 is enabled
deepspeed_zero3_enable(zero.get("stage") == 3)
# automatically assign the optimal config values based on model config
hidden_size = model.config.hidden_size
if zero.get("reduce_bucket_size") == 0:
zero["reduce_bucket_size"] = hidden_size * hidden_size
if zero.get("stage3_prefetch_bucket_size") == 0:
zero["stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size
if zero.get("stage3_param_persistence_threshold") == 0:
zero["stage3_param_persistence_threshold"] = 10 * hidden_size
# keep for quick debug:
# from pprint import pprint; pprint(config)
# init that takes part of the config via `args`, and the bulk of it via `config_params`
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
model_parameters=model_parameters,
......@@ -448,7 +519,17 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
lr_scheduler=lr_scheduler,
)
if resume_from_checkpoint is not None: # and os.path.isdir(resume_from_checkpoint):
if resume_from_checkpoint is not None:
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
# a resume from a checkpoint and not just a local pretrained weights dir. So we check here if
# the path contains what looks like a deepspeed checkpoint
import glob
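# deepspeed's save_checkpoint() stores its states in sub-dirs named global_step<N>
# inside the checkpoint dir, which is what the glob below looks for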
deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))
if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {resume_from_checkpoint}")
# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = model.load_checkpoint(
......@@ -456,6 +537,8 @@ def init_deepspeed(trainer, num_training_steps, resume_from_checkpoint=None):
)
if load_path is None:
raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
else:
logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")
return model, optimizer, lr_scheduler
......
......@@ -41,6 +41,7 @@ from .file_utils import (
replace_return_docstrings,
)
from .generation_utils import GenerationMixin
from .integrations import is_deepspeed_zero3_enabled
from .utils import logging
......@@ -660,7 +661,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
if new_num_tokens is None:
return old_embeddings
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
else:
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
if old_num_tokens == new_num_tokens:
return old_embeddings
......@@ -677,8 +685,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
self._init_weights(new_embeddings)
# Copy token embeddings from the previous weights
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
# number of tokens to copy
n = min(old_num_tokens, new_num_tokens)
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
else:
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
return new_embeddings
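# The gather/modify pattern above in a nutshell (a sketch, assuming an initialized ZeRO-3 run):
# inside GatheredParameters the partitioned params are temporarily reassembled in full on each
# rank; with modifier_rank=0, only rank 0's in-context writes are scattered back to the shards
# on exit, hence the torch.distributed.get_rank() == 0 guard:
#
#   with deepspeed.zero.GatheredParameters(embedding.weight, modifier_rank=0):
#       if torch.distributed.get_rank() == 0:
#           embedding.weight.data.fill_(0.0)  # synced back to all shards on exit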
......@@ -1056,6 +1073,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
config.name_or_path = pretrained_model_name_or_path
# Instantiate model.
if is_deepspeed_zero3_enabled():
import deepspeed
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model to avoid the overhead in time and memory copying it on CPU or each GPU first
with deepspeed.zero.Init():
model = cls(config, *model_args, **model_kwargs)
else:
model = cls(config, *model_args, **model_kwargs)
if state_dict is None and not from_tf:
......@@ -1114,15 +1140,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
# so we need to apply the function recursively.
def load(module: nn.Module, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
args = (state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
if is_deepspeed_zero3_enabled():
import deepspeed
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads them from
# the state dict, and then re-partitions them
with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + ".")
......
......@@ -17,7 +17,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
"""
import collections
import gc
import inspect
import math
import os
......@@ -41,7 +40,8 @@ from .integrations import ( # isort: split
is_ray_tune_available,
run_hp_search_optuna,
run_hp_search_ray,
init_deepspeed,
deepspeed_init,
is_deepspeed_zero3_enabled,
)
import numpy as np
......@@ -921,7 +921,7 @@ class Trainer:
logger.info(f"Loading model from {resume_from_checkpoint}).")
if self.deepspeed:
# will be resumed in init_deepspeed
# will be resumed in deepspeed_init
pass
elif isinstance(self.model, PreTrainedModel):
self.model = self.model.from_pretrained(resume_from_checkpoint)
......@@ -965,12 +965,12 @@ class Trainer:
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(
deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
)
self.model = model.module
self.model_wrapped = model
self.deepspeed = model # DeepSpeedEngine object
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
elif not delay_optimizer_creation:
......@@ -1227,18 +1227,6 @@ class Trainer:
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
if self.deepspeed:
# free up any memory that might be useful for eval
self.deepspeed = None
self.optimizer = None
self.lr_scheduler = None
self.model_wrapped = self.model
gc.collect() # force memory release
# to restore normal behavior outside of train replay the place_model_on_device logic w/o deepspeed
self.place_model_on_device = self.args.place_model_on_device
if self.is_model_parallel:
self.place_model_on_device = False
self.is_in_train = False
self._memory_tracker.stop_and_update_metrics(metrics)
......@@ -1293,6 +1281,8 @@ class Trainer:
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir)
if self.deepspeed:
# under zero3 the model file itself doesn't get saved since it's bogus (just zero3 placeholders),
# unless the deepspeed config has `stage3_gather_fp16_weights_on_model_save` set to True
self.deepspeed.save_checkpoint(output_dir)
# Save optimizer and scheduler
......@@ -1351,7 +1341,7 @@ class Trainer:
return
if self.deepspeed:
# deepspeed loads optimizer/lr_scheduler together with the model in init_deepspeed
# deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
return
if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
......@@ -1597,6 +1587,10 @@ class Trainer:
Will only save from the main process.
"""
if output_dir is None:
output_dir = self.args.output_dir
if is_torch_tpu_available():
self._save_tpu(output_dir)
elif is_sagemaker_mp_enabled():
......@@ -1608,8 +1602,31 @@ class Trainer:
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
):
state_dict = self.model.state_dict()
if self.is_world_process_zero():
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:
# this takes care of everything as long as we aren't under zero3
if self.is_world_process_zero():
self._save(output_dir)
if is_deepspeed_zero3_enabled():
# It's too complicated to try to override all the different places where the weights dump gets
# saved, so since under zero3 the file is bogus, simply delete it. To resume, the user should
# either use the deepspeed checkpoint, or run zero_to_fp32.py (stored in the checkpoint) to
# recover the full weights.
if self.is_world_process_zero():
file = os.path.join(output_dir, WEIGHTS_NAME)
if os.path.isfile(file):
# logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
os.remove(file)
# now save the real model if `stage3_gather_fp16_weights_on_model_save` is True;
# if it is False the model will not be saved.
# This must be called on all ranks
self.deepspeed.save_fp16_model(output_dir, WEIGHTS_NAME)
elif self.is_world_process_zero():
self._save(output_dir)
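# Recovery sketch (paths are illustrative): when `stage3_gather_fp16_weights_on_model_save` is
# true, save_fp16_model() above leaves a consolidated fp16 checkpoint that loads like any
# regular one:
#
#   import torch
#   state_dict = torch.load("output_dir/pytorch_model.bin", map_location="cpu")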
......@@ -1848,10 +1865,20 @@ class Trainer:
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
)
if self.args.deepspeed and not self.args.do_train:
# no harm, but flagging to the user that the deepspeed config is ignored for eval;
# flagging only when --do_train wasn't passed, as only then is it redundant
logger.info("Detected the deepspeed argument but it will not be used for evaluation")
# if eval is called w/o train, init deepspeed here
if self.args.deepspeed and not self.deepspeed:
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
# from the checkpoint eventually
deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
self.model = deepspeed_engine.module
self.model_wrapped = deepspeed_engine
self.deepspeed = deepspeed_engine
# XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
# for example the Z3-optimizer is a must for zero3 to work even for inference; what we
# don't need is the deepspeed basic optimizer, which is self.optimizer.optimizer
deepspeed_engine.optimizer.optimizer = None
deepspeed_engine.lr_scheduler = None
model = self._wrap_model(self.model, training=False)
......
......@@ -19,6 +19,7 @@ from packaging import version
from torch import nn
from torch.utils.data.dataset import Dataset
from .integrations import is_deepspeed_zero3_enabled
from .trainer import Trainer
from .trainer_utils import PredictionOutput
from .utils import logging
......@@ -156,9 +157,11 @@ class Seq2SeqTrainer(Trainer):
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well
gen_kwargs = {
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
}
generated_tokens = self.model.generate(
......
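Why `synced_gpus` is needed under ZeRO-3: every forward pass triggers collective all-gathers of the weight shards, so each rank must keep stepping the generation loop until all ranks have finished, otherwise the collectives deadlock. A conceptual sketch of the stopping rule (a hypothetical helper, not the actual `generate()` internals):

import torch
import torch.distributed as dist

def should_stop(this_peer_finished: bool) -> bool:
    # all ranks vote; generation stops only when no rank is still going, so the
    # ZeRO-3 weight all-gathers inside forward() stay aligned across ranks
    still_going = torch.tensor(0.0 if this_peer_finished else 1.0)
    if dist.is_initialized():
        dist.all_reduce(still_going, op=dist.ReduceOp.SUM)
    return still_going.item() == 0.0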
......@@ -132,6 +132,7 @@ class RegressionModelConfig(PretrainedConfig):
self.a = a
self.b = b
self.double_output = double_output
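# dummy value, needed because deepspeed_init()'s zero3 auto-config reads model.config.hidden_size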
self.hidden_size = 1
if is_torch_available():
......