Unverified Commit 9a671853 authored by Sylvain Gugger, committed by GitHub

Experimental support for fairscale ShardedDDP (#9139)

* Experimental support for fairscale ShardedDDP

* Add import error if fairscale not available

* Address review comments

* Fix seq2seq trainer
parent 1c1a2ffb
@@ -20,6 +20,7 @@ from torch.utils.data import DistributedSampler, RandomSampler
 
 from transformers import PreTrainedModel, Trainer, logging
 from transformers.file_utils import is_torch_tpu_available
+from transformers.integrations import is_fairscale_available
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
     Adafactor,
@@ -35,6 +36,10 @@ from transformers.trainer_pt_utils import get_tpu_sampler
 from transformers.training_args import ParallelMode
 
 
+if is_fairscale_available():
+    from fairscale.optim import OSS
+
+
 logger = logging.get_logger(__name__)
 
 arg_to_scheduler = {
@@ -99,18 +104,25 @@ class Seq2SeqTrainer(Trainer):
                     "weight_decay": 0.0,
                 },
             ]
+            optimizer_cls = Adafactor if self.args.adafactor else AdamW
             if self.args.adafactor:
-                self.optimizer = Adafactor(
-                    optimizer_grouped_parameters,
-                    lr=self.args.learning_rate,
-                    scale_parameter=False,
-                    relative_step=False,
-                )
+                optimizer_cls = Adafactor
+                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
             else:
-                self.optimizer = AdamW(
-                    optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon
-                )
+                optimizer_cls = AdamW
+                optimizer_kwargs = {
+                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
+                    "eps": self.args.adam_epsilon,
+                }
+            optimizer_kwargs["lr"] = self.args.learning_rate
+            if self.sharded_dpp:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=optimizer_cls,
+                    **optimizer_kwargs,
+                )
+            else:
+                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
 
         if self.lr_scheduler is None:
             self.lr_scheduler = self._get_lr_scheduler(num_training_steps)
...
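A note on the pattern above: the optimizer class and its keyword arguments are now built first, then either instantiated directly or handed to fairscale's OSS wrapper, which takes the optimizer *class* (not an instance) and shards its state across ranks. Below is a minimal standalone sketch of that split; the toy model, hyperparameter values, and the use_sharded_ddp toggle are illustrative placeholders, not part of this commit.

import torch
from torch.optim import AdamW

from fairscale.optim import OSS

model = torch.nn.Linear(8, 2)                       # toy model (placeholder)
optimizer_cls = AdamW
optimizer_kwargs = {"betas": (0.9, 0.999), "eps": 1e-8, "lr": 3e-5}

use_sharded_ddp = False                             # would come from the training arguments
if use_sharded_ddp:
    # OSS receives the optimizer class plus its kwargs and shards the optimizer
    # state across the distributed ranks (requires an initialized process group).
    optimizer = OSS(params=model.parameters(), optim=optimizer_cls, **optimizer_kwargs)
else:
    optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)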
@@ -92,6 +92,13 @@ try:
 except ImportError:
     _has_mlflow = False
 
+try:
+    import fairscale  # noqa: F401
+
+    _has_fairscale = True
+except ImportError:
+    _has_fairscale = False
+
 # No transformer imports above this point
 from .file_utils import is_torch_tpu_available  # noqa: E402
@@ -128,6 +135,10 @@ def is_mlflow_available():
     return _has_mlflow
 
 
+def is_fairscale_available():
+    return _has_fairscale
+
+
 def hp_params(trial):
     if is_optuna_available():
         if isinstance(trial, optuna.Trial):
...
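The new probe follows the existing `_has_mlflow` pattern: a guarded import at module load sets a flag, and `is_fairscale_available()` exposes it. A short, hedged example of how downstream code can consume it (the message text is illustrative):

from transformers.integrations import is_fairscale_available

if is_fairscale_available():
    # Safe to import: the top-level `import fairscale` already succeeded.
    from fairscale.optim import OSS  # noqa: F401
else:
    print("fairscale is not installed; run `pip install fairscale` to enable sharded DDP.")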
@@ -33,6 +33,7 @@ from .integrations import (  # isort: split
     hp_params,
     is_azureml_available,
     is_comet_available,
+    is_fairscale_available,
     is_mlflow_available,
     is_optuna_available,
     is_ray_available,
@@ -153,6 +154,11 @@ if is_azureml_available():
     DEFAULT_CALLBACKS.append(AzureMLCallback)
 
+if is_fairscale_available():
+    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
+    from fairscale.optim import OSS
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
 logger = logging.get_logger(__name__)
@@ -285,6 +291,16 @@ class Trainer:
             if isinstance(eval_dataset, datasets.Dataset):
                 self._remove_unused_columns(self.eval_dataset, description="evaluation")
 
+        # Setup Sharded DDP training
+        self.sharded_dpp = False
+        if args.sharded_ddp:
+            if args.local_rank == -1:
+                raise ValueError("Using sharded DDP only works in distributed training.")
+            elif not is_fairscale_available():
+                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+            else:
+                self.sharded_dpp = True
+
         # Mixed precision setup
         self.use_apex = False
         self.use_amp = False
@@ -296,7 +312,7 @@
             if backend == "amp":
                 self.use_amp = True
-                self.scaler = torch.cuda.amp.GradScaler()
+                self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
             else:
                 if not is_apex_available():
                     raise ImportError(
@@ -491,12 +507,21 @@
                     "weight_decay": 0.0,
                 },
             ]
-            self.optimizer = AdamW(
-                optimizer_grouped_parameters,
-                lr=self.args.learning_rate,
-                betas=(self.args.adam_beta1, self.args.adam_beta2),
-                eps=self.args.adam_epsilon,
-            )
+            if self.sharded_dpp:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=AdamW,
+                    lr=self.args.learning_rate,
+                    betas=(self.args.adam_beta1, self.args.adam_beta2),
+                    eps=self.args.adam_epsilon,
+                )
+            else:
+                self.optimizer = AdamW(
+                    optimizer_grouped_parameters,
+                    lr=self.args.learning_rate,
+                    betas=(self.args.adam_beta1, self.args.adam_beta2),
+                    eps=self.args.adam_epsilon,
+                )
         if self.lr_scheduler is None:
             self.lr_scheduler = get_linear_schedule_with_warmup(
                 self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
@@ -643,7 +668,9 @@
             model = torch.nn.DataParallel(model)
 
         # Distributed training (should be after apex fp16 initialization)
-        if self.args.local_rank != -1:
+        if self.sharded_dpp:
+            model = ShardedDDP(model, self.optimizer)
+        elif self.args.local_rank != -1:
             model = torch.nn.parallel.DistributedDataParallel(
                 model,
                 device_ids=[self.args.local_rank],
@@ -654,8 +681,8 @@
                     else True
                 ),
             )
             # find_unused_parameters breaks checkpointing as per
             # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
 
         # Train!
         if is_torch_tpu_available():
@@ -895,6 +922,8 @@
             self.save_model(output_dir)
 
         # Save optimizer and scheduler
+        if self.sharded_dpp:
+            self.optimizer.consolidate_state_dict()
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
...
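Taken together, the Trainer changes follow fairscale's order of operations: build the OSS optimizer first, wrap the model in ShardedDDP with that optimizer, swap in ShardedGradScaler for AMP, and consolidate the sharded optimizer state before saving. A minimal sketch of the same wiring outside the Trainer, assuming a process group was already initialized by a distributed launcher, one GPU per process, and fairscale installed; the toy model, learning rate, and file name are placeholders.

import torch
from torch.optim import AdamW

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# Assumes torch.distributed.init_process_group(...) was already called by the launcher.
model = torch.nn.Linear(16, 2).cuda()

# 1. The sharded optimizer comes first: OSS partitions optimizer state across ranks.
optimizer = OSS(params=model.parameters(), optim=AdamW, lr=3e-5)

# 2. The model is then wrapped with ShardedDDP, which needs the OSS optimizer to
#    route each reduced gradient to the rank that owns its shard.
model = ShardedDDP(model, optimizer)

# 3. With fp16, the sharded setup uses fairscale's scaler instead of torch.cuda.amp.GradScaler.
scaler = ShardedGradScaler()

# 4. Before checkpointing, the partitioned state is gathered so a full state_dict can be saved.
optimizer.consolidate_state_dict()
if torch.distributed.get_rank() == 0:
    torch.save(optimizer.state_dict(), "optimizer.pt")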
@@ -215,6 +215,9 @@ class TrainingArguments:
             The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
             :obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
             other choices will force the requested backend.
+        sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
+            training only). This is an experimental feature.
     """
 
     output_dir: str = field(
@@ -386,6 +389,10 @@
         default="auto",
         metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
     )
+    sharded_ddp: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
+    )
 
     def __post_init__(self):
         if self.disable_tqdm is None:
...
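Finally, a hedged usage sketch of the new argument. Because `sharded_ddp` is a boolean field defaulting to `False`, `HfArgumentParser` should expose it as a plain `--sharded_ddp` flag; the script name and launch command in the comment are illustrative, not part of the commit. Keep in mind the checks added in the Trainer: the flag only works under a distributed launcher (so that `local_rank != -1`) and with fairscale installed.

from transformers import HfArgumentParser, TrainingArguments

# Illustrative launch (script name is a placeholder):
#   python -m torch.distributed.launch --nproc_per_node=2 my_train_script.py \
#       --output_dir out --sharded_ddp --fp16
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.sharded_ddp)  # True when --sharded_ddp was passed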