Unverified Commit 27597fea authored by statelesshz, committed by GitHub

remove SharedDDP as it is deprecated (#25702)



* remove SharedDDP as it was deprecated

* apply review suggestion

* make style

* Oops, forgot to remove the compute_loss context manager in Seq2SeqTrainer.

* remove the unnecessary conditional statement

* keep the logic of IPEX

* clean code

* mixed precision setup & make fixup

---------
Co-authored-by: statelesshz <jihuazhong1@huawei.com>
parent e840aa67
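Note for readers migrating existing configs: the `sharded_ddp` argument and its fairscale backend are gone after this commit, and the maintained replacement is PyTorch FSDP through the existing `fsdp` argument. A minimal sketch of the switch, assuming a multi-GPU launch; the chosen FSDP options are illustrative rather than prescribed by this PR:

from transformers import TrainingArguments

# Before this commit (now removed): fairscale-backed ZeRO-style sharding
# args = TrainingArguments(output_dir="out", fp16=True, sharded_ddp="zero_dp_3")

# After this commit: use PyTorch FSDP instead; "full_shard auto_wrap" is one
# plausible combination of the existing FSDPOption values.
args = TrainingArguments(
    output_dir="out",
    fp16=True,  # CUDA mixed precision is handled by Accelerate, not a Trainer-local scaler
    fsdp="full_shard auto_wrap",
)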
@@ -19,7 +19,6 @@ from torch import nn
 from torch.utils.data import DistributedSampler, RandomSampler
 from transformers import PreTrainedModel, Trainer, logging
-from transformers.integrations import is_fairscale_available
 from transformers.models.fsmt.configuration_fsmt import FSMTConfig
 from transformers.optimization import (
     Adafactor,
@@ -36,10 +35,6 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_tpu_available

-if is_fairscale_available():
-    from fairscale.optim import OSS
-
 logger = logging.get_logger(__name__)

 arg_to_scheduler = {
@@ -118,13 +113,6 @@ class Seq2SeqTrainer(Trainer):
                 "eps": self.args.adam_epsilon,
             }
             optimizer_kwargs["lr"] = self.args.learning_rate
-            if self.sharded_ddp:
-                self.optimizer = OSS(
-                    params=optimizer_grouped_parameters,
-                    optim=optimizer_cls,
-                    **optimizer_kwargs,
-                )
-            else:
-                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

         if self.lr_scheduler is None:
......
@@ -109,7 +109,6 @@ _deps = [
     "diffusers",
     "dill<0.3.5",
     "evaluate>=0.2.0",
-    "fairscale>0.3",
     "faiss-cpu",
     "fastapi",
     "filelock",
@@ -275,7 +274,6 @@ extras["modelcreation"] = deps_list("cookiecutter")
 extras["sagemaker"] = deps_list("sagemaker")
 extras["deepspeed"] = deps_list("deepspeed") + extras["accelerate"]
-extras["fairscale"] = deps_list("fairscale")
 extras["optuna"] = deps_list("optuna")
 extras["ray"] = deps_list("ray[tune]")
 extras["sigopt"] = deps_list("sigopt")
......
@@ -16,7 +16,6 @@ deps = {
     "diffusers": "diffusers",
     "dill": "dill<0.3.5",
     "evaluate": "evaluate>=0.2.0",
-    "fairscale": "fairscale>0.3",
     "faiss-cpu": "faiss-cpu",
     "fastapi": "fastapi",
     "filelock": "filelock",
......
@@ -57,7 +57,6 @@ _import_structure = {
         "is_codecarbon_available",
         "is_comet_available",
         "is_dagshub_available",
-        "is_fairscale_available",
         "is_flyte_deck_standard_available",
         "is_flytekit_available",
         "is_mlflow_available",
@@ -118,7 +117,6 @@ if TYPE_CHECKING:
         is_codecarbon_available,
         is_comet_available,
         is_dagshub_available,
-        is_fairscale_available,
         is_flyte_deck_standard_available,
         is_flytekit_available,
         is_mlflow_available,
......
@@ -134,10 +134,6 @@ def is_dagshub_available():
     return None not in [importlib.util.find_spec("dagshub"), importlib.util.find_spec("mlflow")]

-def is_fairscale_available():
-    return importlib.util.find_spec("fairscale") is not None
-
 def is_neptune_available():
     return _has_neptune
......
@@ -42,7 +42,6 @@ from transformers import logging as transformers_logging
 from .integrations import (
     is_clearml_available,
-    is_fairscale_available,
     is_optuna_available,
     is_ray_available,
     is_sigopt_available,
@@ -871,13 +870,6 @@ def require_deepspeed(test_case):
     return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)

-def require_fairscale(test_case):
-    """
-    Decorator marking a test that requires fairscale
-    """
-    return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)
-
 def require_apex(test_case):
     """
     Decorator marking a test that requires apex
......
@@ -40,7 +40,6 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
 from .integrations import (
     get_reporting_integration_callbacks,
     hp_params,
-    is_fairscale_available,
 )

 # isort: on
@@ -58,7 +57,6 @@ from . import __version__
 from .configuration_utils import PretrainedConfig
 from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
 from .debug_utils import DebugOption, DebugUnderflowOverflow
-from .dependency_versions_check import dep_version_check
 from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
 from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
 from .modelcard import TrainingSummary
@@ -107,7 +105,6 @@ from .trainer_utils import (
     IntervalStrategy,
     PredictionOutput,
     RemoveColumnsCollator,
-    ShardedDDPOption,
     TrainerMemoryTracker,
     TrainOutput,
     default_compute_objective,
@@ -171,15 +168,6 @@ if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met

-if is_fairscale_available():
-    dep_version_check("fairscale")
-    import fairscale
-    from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
-    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
-    from fairscale.nn.wrap import auto_wrap
-    from fairscale.optim import OSS
-    from fairscale.optim.grad_scaler import ShardedGradScaler
-
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
@@ -420,33 +408,6 @@ class Trainer:
                 " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
             )

-        # Setup Sharded DDP training
-        self.sharded_ddp = None
-        if len(args.sharded_ddp) > 0:
-            if self.is_deepspeed_enabled:
-                raise ValueError(
-                    "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
-                )
-            if len(args.fsdp) > 0:
-                raise ValueError(
-                    "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
-                )
-            if args.parallel_mode != ParallelMode.DISTRIBUTED:
-                raise ValueError("Using sharded DDP only works in distributed training.")
-            elif not is_fairscale_available():
-                raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
-            elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
-                raise ImportError(
-                    "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
-                    f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
-                )
-            elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
-                self.sharded_ddp = ShardedDDPOption.SIMPLE
-            elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
-                self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
-            elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
-                self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
-
         self.fsdp = None
         if len(args.fsdp) > 0:
             if self.is_deepspeed_enabled:
@@ -488,14 +449,12 @@ class Trainer:
         # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
         #    and we only use deepspeed for training at the moment
         # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
-        # 4. Sharded DDP - same as MP
-        # 5. FSDP - same as MP
+        # 4. FSDP - same as MP
         self.place_model_on_device = args.place_model_on_device
         if (
             self.is_model_parallel
             or self.is_deepspeed_enabled
             or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
-            or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
             or (self.fsdp is not None)
             or self.is_fsdp_enabled
         ):
@@ -545,11 +504,11 @@ class Trainer:
                 " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
                 " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
             )
-        if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
+        if (self.is_deepspeed_enabled or (self.fsdp is not None)) and (
             self.optimizer is not None or self.lr_scheduler is not None
         ):
             raise RuntimeError(
-                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled."
+                "Passing `optimizers` is not allowed if Deepspeed or PyTorch FSDP is enabled."
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )

         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
@@ -592,7 +551,6 @@ class Trainer:
         # Mixed precision setup
         self.use_apex = False
-        self.use_cuda_amp = False
         self.use_cpu_amp = False

         # Mixed precision setup for SageMaker Model Parallel
@@ -617,31 +575,17 @@ class Trainer:
                     f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
                     "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
                 )
-        if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
-            if args.half_precision_backend == "auto":
-                if args.device == torch.device("cpu"):
-                    if args.fp16:
-                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
-                    else:
-                        args.half_precision_backend = "cpu_amp"
-                else:
-                    args.half_precision_backend = "cuda_amp"
+        if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
+            if args.device == torch.device("cpu"):
+                if args.fp16:
+                    raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+                else:
+                    args.half_precision_backend = "cpu_amp"
             logger.info(f"Using {args.half_precision_backend} half precision backend")

-        self.do_grad_scaling = False
         if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
-            if self.sharded_ddp is not None:
-                if args.half_precision_backend == "cuda_amp":
-                    self.use_cuda_amp = True
-                    self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
-                    # bf16 does not need grad scaling
-                    self.do_grad_scaling = self.amp_dtype == torch.float16
-                    if self.do_grad_scaling:
-                        self.scaler = ShardedGradScaler()
-                elif args.half_precision_backend == "cpu_amp":
-                    self.use_cpu_amp = True
-                    self.amp_dtype = torch.bfloat16
+            if args.half_precision_backend == "cpu_amp":
+                self.use_cpu_amp = True
+                self.amp_dtype = torch.bfloat16
             elif args.half_precision_backend == "apex":
@@ -652,18 +596,6 @@ class Trainer:
                 )
             self.use_apex = True

-        # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
-        if (
-            is_sagemaker_mp_enabled()
-            and self.use_cuda_amp
-            and args.max_grad_norm is not None
-            and args.max_grad_norm > 0
-        ):
-            raise ValueError(
-                "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
-                "along 'max_grad_norm': 0 in your hyperparameters."
-            )
-
         # Label smoothing
         if self.args.label_smoothing_factor != 0:
             self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
@@ -994,13 +926,6 @@ class Trainer:
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)

-            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                self.optimizer = OSS(
-                    params=optimizer_grouped_parameters,
-                    optim=optimizer_cls,
-                    **optimizer_kwargs,
-                )
-            else:
-                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

             if optimizer_cls.__name__ == "Adam8bit":
                 import bitsandbytes
@@ -1333,7 +1258,6 @@ class Trainer:
                     jit_model(**example_batch)
                 model = jit_model
                 self.use_cpu_amp = False
-                self.use_cuda_amp = False
         except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
             logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
@@ -1396,25 +1320,8 @@ class Trainer:
             return model

         # Distributed training (should be after apex fp16 initialization)
-        if self.sharded_ddp is not None:
-            # Sharded DDP!
-            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-                model = ShardedDDP(model, self.optimizer)
-            else:
-                mixed_precision = self.args.fp16 or self.args.bf16
-                cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
-                zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
-                # XXX: Breaking the self.model convention but I see no way around it for now.
-                if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
-                    model = auto_wrap(model)
-                self.model = model = FullyShardedDDP(
-                    model,
-                    mixed_precision=mixed_precision,
-                    reshard_after_forward=zero_3,
-                    cpu_offload=cpu_offload,
-                ).to(self.args.device)
         # Distributed training using PyTorch FSDP
-        elif self.fsdp is not None and self.args.fsdp_config["xla"]:
+        if self.fsdp is not None and self.args.fsdp_config["xla"]:
             try:
                 from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
                 from torch_xla.distributed.fsdp import checkpoint_module
@@ -1669,13 +1576,7 @@ class Trainer:
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

-        delay_optimizer_creation = (
-            self.sharded_ddp is not None
-            and self.sharded_ddp != ShardedDDPOption.SIMPLE
-            or is_sagemaker_mp_enabled()
-            or self.fsdp is not None
-            or self.is_fsdp_enabled
-        )
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.fsdp is not None or self.is_fsdp_enabled

         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -1716,7 +1617,7 @@ class Trainer:
         # as the model is wrapped, don't use `accelerator.prepare`
         # this is for unhandled cases such as
-        # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
         use_accelerator_prepare = True if model is self.model else False

         if delay_optimizer_creation:
@@ -1932,14 +1833,6 @@ class Trainer:
                     if args.max_grad_norm is not None and args.max_grad_norm > 0:
                         # deepspeed does its own clipping

-                        if self.do_grad_scaling:
-                            # Reduce gradients first for XLA
-                            if is_torch_tpu_available():
-                                gradients = xm._fetch_gradients(self.optimizer)
-                                xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
-                            # AMP: gradients need unscaling
-                            self.scaler.unscale_(self.optimizer)
-
                         if is_sagemaker_mp_enabled() and args.fp16:
                             self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif hasattr(self.optimizer, "clip_grad_norm"):
@@ -1961,24 +1854,8 @@ class Trainer:
                         )

                     # Optimizer step
-                    optimizer_was_run = True
-                    if is_torch_tpu_available():
-                        if self.do_grad_scaling:
-                            self.scaler.step(self.optimizer)
-                            self.scaler.update()
-                        else:
-                            # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step
-                            self.optimizer.step()
-                    elif self.do_grad_scaling:
-                        scale_before = self.scaler.get_scale()
-                        self.scaler.step(self.optimizer)
-                        self.scaler.update()
-                        scale_after = self.scaler.get_scale()
-                        optimizer_was_run = scale_before <= scale_after
-                    else:
-                        self.optimizer.step()
-                        optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+                    self.optimizer.step()
+                    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped

                     if optimizer_was_run:
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -2408,9 +2285,6 @@ class Trainer:
             self.model_wrapped.save_checkpoint(output_dir)

         # Save optimizer and scheduler
-        if self.sharded_ddp == ShardedDDPOption.SIMPLE:
-            self.optimizer.consolidate_state_dict()
-
         if self.fsdp or self.is_fsdp_enabled:
             if self.is_fsdp_enabled:
                 save_fsdp_optimizer(
@@ -2455,8 +2329,6 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
-            if self.do_grad_scaling:
-                torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
@@ -2600,8 +2472,6 @@ class Trainer:
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             reissue_pt_warnings(caught_warnings)
-            if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
-                self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))

     def hyperparameter_search(
         self,
@@ -2744,12 +2614,8 @@ class Trainer:
         A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
         arguments, depending on the situation.
         """
-        if self.use_cuda_amp or self.use_cpu_amp:
-            ctx_manager = (
-                torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
-                if self.use_cpu_amp
-                else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
-            )
+        if self.use_cpu_amp:
+            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
         else:
             ctx_manager = contextlib.nullcontext()
@@ -2786,9 +2652,7 @@ class Trainer:
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training

-        if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
-        elif self.use_apex:
+        if self.use_apex:
             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                 scaled_loss.backward()
         else:
@@ -2872,12 +2736,7 @@ class Trainer:
             if IS_SAGEMAKER_MP_POST_1_10:
                 # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                 Path(os.path.join(output_dir, "user_content.pt")).touch()
-        elif (
-            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
-            or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
-            or self.fsdp is not None
-            or self.is_fsdp_enabled
-        ):
+        elif self.fsdp is not None or self.is_fsdp_enabled:
             state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}

             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
......
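Note on the removed `do_grad_scaling` / `ShardedGradScaler` branches above: they are not replaced with new in-Trainer logic; CUDA mixed precision and loss scaling are expected to come from Accelerate, which the Trainer already drives. A minimal sketch of that division of labor, assuming a CUDA device and a recent `accelerate` install (illustrative only, not the Trainer's actual code path):

import torch
from accelerate import Accelerator

# Accelerate owns autocast and the GradScaler; "fp16" assumes a CUDA device
# (use "bf16" or "no" otherwise).
accelerator = Accelerator(mixed_precision="fp16")

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

inputs = torch.randn(4, 8, device=accelerator.device)
loss = model(inputs).sum()

accelerator.backward(loss)                            # scales the loss when fp16 needs it
accelerator.clip_grad_norm_(model.parameters(), 1.0)  # replaces manual unscale_ + clip
optimizer.step()                                      # skipped steps surface via optimizer_step_was_skipped
optimizer.zero_grad()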
@@ -266,7 +266,6 @@ class Seq2SeqTrainer(Trainer):
         has_labels = "labels" in inputs
         inputs = self._prepare_inputs(inputs)

-        # XXX: adapt synced_gpus for fairscale as well
         # Priority (handled in generate):
         # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
         if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
......
@@ -651,14 +651,6 @@ def number_of_arguments(func):
     return len(inspect.signature(func).parameters)

-class ShardedDDPOption(ExplicitEnum):
-    SIMPLE = "simple"
-    ZERO_DP_2 = "zero_dp_2"
-    ZERO_DP_3 = "zero_dp_3"
-    OFFLOAD = "offload"
-    AUTO_WRAP = "auto_wrap"
-
 def find_executable_batch_size(
     function: callable = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
 ):
......
@@ -34,7 +34,6 @@ from .trainer_utils import (
     HubStrategy,
     IntervalStrategy,
     SchedulerType,
-    ShardedDDPOption,
 )
 from .utils import (
     ExplicitEnum,
@@ -328,9 +327,9 @@ class TrainingArguments:
         fp16_backend (`str`, *optional*, defaults to `"auto"`):
             This argument is deprecated. Use `half_precision_backend` instead.
         half_precision_backend (`str`, *optional*, defaults to `"auto"`):
-            The backend to use for mixed precision training. Must be one of `"auto", "cuda_amp", "apex", "cpu_amp"`.
-            `"auto"` will use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices
-            will force the requested backend.
+            The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
+            use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
+            requested backend.
         bf16_full_eval (`bool`, *optional*, defaults to `False`):
             Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
             metric values. This is an experimental API and it may change.
@@ -410,21 +409,6 @@ class TrainingArguments:
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
             can take a long time) but will not yield the same results as the interrupted training would have.
-        sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`):
-            Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed
-            training only). This is an experimental feature.
-
-            A list of options along the following:
-
-            - `"simple"`: to use first instance of sharded DDP released by fairscale (`ShardedDDP`) similar to ZeRO-2.
-            - `"zero_dp_2"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
-              Zero-2 mode (with `reshard_after_forward=False`).
-            - `"zero_dp_3"`: to use the second instance of sharded DPP released by fairscale (`FullyShardedDDP`) in
-              Zero-3 mode (with `reshard_after_forward=True`).
-            - `"offload"`: to add ZeRO-offload (only compatible with `"zero_dp_2"` and `"zero_dp_3"`).
-
-            If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
-            list for `False` and `["simple"]` for `True`.
         fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
             Use PyTorch Distributed Parallel Training (in distributed training only).
@@ -877,7 +861,7 @@ class TrainingArguments:
         default="auto",
         metadata={
             "help": "The backend to be used for half precision.",
-            "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+            "choices": ["auto", "apex", "cpu_amp"],
         },
     )
     bf16_full_eval: bool = field(
@@ -996,17 +980,6 @@ class TrainingArguments:
             )
         },
     )
-    sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field(
-        default="",
-        metadata={
-            "help": (
-                "Whether or not to use sharded DDP training (in distributed training only). The base option should be"
-                " `simple`, `zero_dp_2` or `zero_dp_3` and you can add CPU-offload to `zero_dp_2` or `zero_dp_3` like"
-                " this: zero_dp_2 offload` or `zero_dp_3 offload`. You can add auto-wrap to `zero_dp_2` or `zero_dp_3`"
-                " with the same syntax: zero_dp_2 auto_wrap` or `zero_dp_3 auto_wrap`."
-            ),
-        },
-    )
     fsdp: Optional[Union[List[FSDPOption], str]] = field(
         default="",
         metadata={
@@ -1154,7 +1127,7 @@ class TrainingArguments:
         default="auto",
         metadata={
             "help": "Deprecated. Use half_precision_backend instead",
-            "choices": ["auto", "cuda_amp", "apex", "cpu_amp"],
+            "choices": ["auto", "apex", "cpu_amp"],
         },
     )
     push_to_hub_model_id: Optional[str] = field(
@@ -1407,8 +1380,6 @@ class TrainingArguments:
                 " `--half_precision_backend apex`: GPU bf16 is not supported by apex. Use"
                 " `--half_precision_backend cuda_amp` instead"
             )
-            if not (self.sharded_ddp == "" or not self.sharded_ddp):
-                raise ValueError("sharded_ddp is not supported with bf16")

         if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
             if self.evaluation_strategy == IntervalStrategy.NO:
@@ -1508,7 +1479,7 @@ class TrainingArguments:
         # no need to assert on else

         # if training args is specified, it will override the one specified in the accelerate config
-        if self.half_precision_backend != "apex" and len(self.sharded_ddp) == 0:
+        if self.half_precision_backend != "apex":
             mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
             if self.fp16:
                 mixed_precision_dtype = "fp16"
@@ -1541,26 +1512,6 @@ class TrainingArguments:
                 " during training"
             )

-        if not (self.sharded_ddp == "" or not self.sharded_ddp):
-            warnings.warn(
-                "using `sharded_ddp` is deprecated and will be removed in version 4.33"
-                " of 🤗 Transformers. Use `fsdp` instead",
-                FutureWarning,
-            )
-
-        if isinstance(self.sharded_ddp, bool):
-            self.sharded_ddp = "simple" if self.sharded_ddp else ""
-        if isinstance(self.sharded_ddp, str):
-            self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
-        if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
-            raise ValueError(
-                "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
-                '`--sharded_ddp zero_dp_3`. For example, `--sharded_ddp "zero_dp_2 offload"`.'
-            )
-        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
-            raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
-        elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
-            raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
-
         if isinstance(self.fsdp, bool):
             self.fsdp = "full_shard" if self.fsdp else ""
         if isinstance(self.fsdp, str):
......
@@ -16,7 +16,6 @@ import math
 import os
 import re
 import sys
-import unittest
 from pathlib import Path
 from typing import Tuple
 from unittest.mock import patch
@@ -32,7 +31,6 @@ from transformers.testing_utils import (
     get_torch_dist_unique_port,
     require_apex,
     require_bitsandbytes,
-    require_fairscale,
     require_torch,
     require_torch_gpu,
     require_torch_multi_gpu,
@@ -105,36 +103,6 @@ class TestTrainerExt(TestCasePlus):
     def test_run_seq2seq_ddp(self):
         self.run_seq2seq_quick(distributed=True)

-    # test --sharded_ddp w/o --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_sharded_ddp(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")
-
-    # test --sharded_ddp w/ --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_sharded_ddp_fp16(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")
-
-    # test --sharded_ddp zero_dp_2 w/o --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_fully_sharded_ddp(self):
-        self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)
-
-    # test --sharded_ddp zero_dp_2 w/ --fp16
-    @unittest.skip("Requires an update of the env running those tests")
-    @require_torch_multi_gpu
-    @require_fairscale
-    def test_run_seq2seq_fully_sharded_ddp_fp16(self):
-        self.run_seq2seq_quick(
-            distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
-        )
-
     @require_apex
     @require_torch_gpu
     def test_run_seq2seq_apex(self):
......
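The four deleted fairscale tests have no direct replacement in this commit. If equivalent coverage were wanted through the surviving FSDP path, a test along these lines could stand in for them (hypothetical, not added by this PR; it reuses the file's existing `run_seq2seq_quick` helper and decorators):

    # hypothetical FSDP counterpart of the removed --sharded_ddp tests
    @require_torch_multi_gpu
    def test_run_seq2seq_fsdp_fp16(self):
        self.run_seq2seq_quick(
            distributed=True, extra_args_str="--fsdp full_shard --fp16", predict_with_generate=False
        )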