Unverified commit 2c73b930, authored by Stas Bekman and committed by GitHub

[Deepspeed] Assert on mismatches between ds and hf args (#12021)



* wip

* add mismatch validation + test

* renames

* Update docs/source/main_classes/deepspeed.rst
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* renames
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 242ec31a
docs/source/main_classes/deepspeed.rst

@@ -537,7 +537,14 @@ difficult to detect ways. You have been warned.
 There are multiple other values that are specific to DeepSpeed-only and those you will have to set manually to suit
 your needs.
 
+In your own programs, you can also use the following approach if you'd like to modify the DeepSpeed config as a master
+and configure :class:`~transformers.TrainingArguments` based on that. The steps are:
+
+1. Create or load the DeepSpeed configuration to be used as a master configuration
+2. Create the :class:`~transformers.TrainingArguments` object based on these values
+
+Do note that some values, such as :obj:`scheduler.params.total_num_steps` are calculated by
+:class:`~transformers.Trainer` during ``train``, but you can of course do the math yourself.
 
 .. _deepspeed-zero:
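To make the workflow described in the documentation hunk above concrete, here is a minimal sketch of treating a DeepSpeed config as the master configuration and deriving :class:`~transformers.TrainingArguments` from it. It is illustrative only: the file name ``ds_config.json``, the ``output_dir`` and the particular keys read from the config are assumptions, not part of this PR.

```python
import json

from transformers import TrainingArguments

# Sketch only: load the DeepSpeed config that acts as the master configuration.
with open("ds_config.json") as f:  # hypothetical path
    ds_config = json.load(f)

# Derive the matching TrainingArguments values from the master config so the two cannot drift apart.
training_args = TrainingArguments(
    output_dir="output",  # hypothetical
    deepspeed=ds_config,
    per_device_train_batch_size=ds_config["train_micro_batch_size_per_gpu"],
    gradient_accumulation_steps=ds_config["gradient_accumulation_steps"],
    learning_rate=ds_config["optimizer"]["params"]["lr"],
)
```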
@@ -20,6 +20,7 @@ import io
 import json
 import weakref
 from copy import deepcopy
+from functools import partialmethod
 
 from .dependency_versions_check import dep_version_check
 from .utils import logging
@@ -32,19 +33,6 @@ def is_deepspeed_available():
     return importlib.util.find_spec("deepspeed") is not None
 
 
-def _is_true(config, key):
-    if config is None:
-        return False
-    return bool(config.get(key))
-
-
-def _set_if_auto(config, key, val):
-    if config is None:
-        return
-    if config.get(key) == "auto":
-        config[key] = val
-
-
 class HfDeepSpeedConfig:
     """
     This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
@@ -89,7 +77,7 @@ class HfDeepSpeedConfig:
         self.offload = False
         config_zero = config.get("zero_optimization", {})
         if self.is_zero2():
-            self.offload = _is_true(config_zero, "cpu_offload")
+            self.offload = self.is_true(config_zero, "cpu_offload")
         elif self.is_zero3():
             offload_devices = ["cpu", "nvme"]
             if config_zero.get("offload_optimizer", {}).get("device") in offload_devices:
@@ -106,6 +94,12 @@ class HfDeepSpeedConfig:
     def is_offload(self):
         return self.offload
 
+    @staticmethod
+    def is_true(config, key):
+        if config is None:
+            return False
+        return bool(config.get(key))
+
 
 class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
     """
@@ -116,37 +110,67 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
 
     def __init__(self, config_file_or_dict):
         super().__init__(config_file_or_dict)
+        self.mismatches = []
+
+    def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
+        """
+        A utility method that massages the config file and can optionally verify that the values match.
+
+        1. Replace "auto" values with the ``TrainingArguments`` value.
+
+        2. If it wasn't "auto" and ``must_match`` is true, check that the DS config value matches the Trainer
+           config value and, if mismatched, add the entry to ``self.mismatches`` - this will assert during
+           ``trainer_config_finalize`` for one or more mismatches.
+        """
+        config = self.config
+
+        # find the config node of interest if it exists
+        nodes = ds_key_long.split(".")
+        ds_key = nodes.pop()
+        for node in nodes:
+            config = config.get(node)
+            if config is None:
+                return
+
+        if config.get(ds_key) == "auto":
+            config[ds_key] = hf_val
+            return
+
+        if not must_match:
+            return
+
+        ds_val = config.get(ds_key)
+        if ds_val is not None and ds_val != hf_val:
+            self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
+
+    fill_only = partialmethod(fill_match, must_match=False)
 
     def trainer_config_process(self, args):
         """
         Adjust the config with ``TrainingArguments`` values. This stage is run during ``TrainingArguments`` object
         creation.
         """
-        config = self.config
-
         # DeepSpeed does:
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
         train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
-        _set_if_auto(config, "train_micro_batch_size_per_gpu", args.per_device_train_batch_size)
-        _set_if_auto(config, "gradient_accumulation_steps", args.gradient_accumulation_steps)
-        _set_if_auto(config, "train_batch_size", train_batch_size)
-        _set_if_auto(config, "gradient_clipping", args.max_grad_norm)
-
-        config_optim = config.get("optimizer", {})
-        if config_optim != {}:
-            config_optim_params = config_optim.get("params")
-            _set_if_auto(config_optim_params, "lr", args.learning_rate)
-            _set_if_auto(config_optim_params, "betas", [args.adam_beta1, args.adam_beta2])
-            _set_if_auto(config_optim_params, "eps", args.adam_epsilon)
-            _set_if_auto(config_optim_params, "weight_decay", args.weight_decay)
-
-        config_sched = config.get("scheduler", {})
-        if config_sched != {}:
-            config_sched_params = config_sched.get("params")
-            _set_if_auto(config_sched_params, "warmup_min_lr", 0)
-            _set_if_auto(config_sched_params, "warmup_max_lr", args.learning_rate)
-            _set_if_auto(config_sched_params, "warmup_num_steps", args.warmup_steps)
-            # total_num_steps - will get set in trainer_config_finalize
+        self.fill_match(
+            "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size"
+        )
+        self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
+        self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)")
+        self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
+
+        self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
+        self.fill_match("optimizer.params.betas", [args.adam_beta1, args.adam_beta2], "adam_beta1+adam_beta2")
+        self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
+        self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
+
+        self.fill_only("scheduler.params.warmup_min_lr", 0)  # not a trainer arg
+        self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
+        self.fill_match("scheduler.params.warmup_num_steps", args.warmup_steps, "warmup_steps")
+        # total_num_steps - will get set in trainer_config_finalize
 
         # fp16
         if args.fp16:
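For readers skimming the hunk above, here is a small standalone sketch of what ``fill_match`` does with dotted keys: an ``"auto"`` leaf is replaced by the ``TrainingArguments`` value, while an explicit value that disagrees is recorded as a mismatch. The function name ``fill_match_sketch`` and the sample config are illustrative only, not part of the library.

```python
def fill_match_sketch(config, ds_key_long, hf_val, hf_key=None, mismatches=None):
    """Simplified stand-in for HfTrainerDeepSpeedConfig.fill_match (illustration only)."""
    node = config
    *parents, ds_key = ds_key_long.split(".")  # "optimizer.params.lr" -> ["optimizer", "params"], "lr"
    for parent in parents:
        node = node.get(parent)
        if node is None:  # the section isn't configured at all - nothing to fill or check
            return
    if node.get(ds_key) == "auto":
        node[ds_key] = hf_val  # "auto" defers to the Trainer value
    elif node.get(ds_key) is not None and node.get(ds_key) != hf_val:
        mismatches.append(f"- ds {ds_key_long}={node[ds_key]} vs hf {hf_key}={hf_val}")


mismatches = []
cfg = {"optimizer": {"params": {"lr": "auto"}}, "gradient_clipping": 0.5}
fill_match_sketch(cfg, "optimizer.params.lr", 3e-5, "learning_rate", mismatches)
fill_match_sketch(cfg, "gradient_clipping", 1.0, "max_grad_norm", mismatches)
assert cfg["optimizer"]["params"]["lr"] == 3e-5
assert mismatches == ["- ds gradient_clipping=0.5 vs hf max_grad_norm=1.0"]
```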
@@ -156,14 +180,12 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
 
         # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
         # any here unless the user did the work
-        config_fp16 = config.get("fp16")
-        _set_if_auto(config_fp16, "enabled", fp16_backend == "amp")
+        self.fill_match("fp16.enabled", fp16_backend == "amp", "fp16+fp16_backend(amp)")
 
         # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
         # ZeRO features
-        config_amp = config.get("amp")
-        _set_if_auto(config_amp, "enabled", fp16_backend == "apex")
-        _set_if_auto(config_amp, "opt_level", args.fp16_opt_level)
+        self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
+        self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
 
     def trainer_config_finalize(self, args, model, num_training_steps):
         """
@@ -171,21 +193,23 @@ class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
 
         Now we can complete the configuration process.
         """
-        config = self.config
-
         # zero
-        config_zero = config.get("zero_optimization", {})
         if self.is_zero3():
             # automatically assign the optimal config values based on model config
             hidden_size = model.config.hidden_size
-            _set_if_auto(config_zero, "reduce_bucket_size", hidden_size * hidden_size)
-            _set_if_auto(config_zero, "stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
-            _set_if_auto(config_zero, "stage3_param_persistence_threshold", 10 * hidden_size)
+            self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
+            self.fill_only("zero_optimization.stage3_prefetch_bucket_size", 0.9 * hidden_size * hidden_size)
+            self.fill_only("zero_optimization.stage3_param_persistence_threshold", 10 * hidden_size)
 
         # scheduler
-        config_sched = config.get("scheduler", {})
-        config_sched_params = config_sched.get("params", {})
-        _set_if_auto(config_sched_params, "total_num_steps", num_training_steps)
+        self.fill_match("scheduler.params.total_num_steps", num_training_steps, "num_training_steps (calculated)")
+
+        if len(self.mismatches) > 0:
+            mismatches = "\n".join(self.mismatches)
+            raise ValueError(
+                f"Please correct the following DeepSpeed config values that mismatch TrainingArguments values:\n{mismatches}\n"
+                "The easiest method is to set these DeepSpeed config values to 'auto'."
+            )
 
 
 # keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
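In user terms, the net effect of these changes is: leave a DeepSpeed config value as ``"auto"`` and the Trainer fills it in, but hard-code a value that disagrees with ``TrainingArguments`` and training now stops with a ``ValueError`` listing the mismatches. A hedged sketch follows; the concrete argument values and ``output_dir`` are made up for illustration.

```python
from transformers import TrainingArguments

# Illustration only: a minimal ZeRO-2 style config where "auto" defers to the Trainer.
ds_config = {
    "train_micro_batch_size_per_gpu": "auto",  # filled from per_device_train_batch_size
    "gradient_accumulation_steps": "auto",     # filled from gradient_accumulation_steps
    "gradient_clipping": "auto",               # filled from max_grad_norm
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "zero_optimization": {"stage": 2},
}

# Had we hard-coded e.g. "gradient_accumulation_steps": 4 above, Trainer.train() would now raise a
# ValueError pointing at "- ds gradient_accumulation_steps=4 vs hf gradient_accumulation_steps=2"
# instead of silently training with two different values.
args = TrainingArguments(
    output_dir="output",  # made-up path
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    max_grad_norm=1.0,
    deepspeed=ds_config,
)
```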
@@ -205,6 +205,58 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
 
     # --- These tests are enough to run on one of zero stages --- #
 
+    def test_hf_ds_config_mismatch(self):
+
+        ds_config = self.get_config_dict(ZERO2)
+
+        # Purposefully configure these values to mismatch TrainingArguments values.
+        # This currently doesn't cover all keys (but it could)
+        per_device_train_batch_size = 2
+        ds_config["train_micro_batch_size_per_gpu"] = per_device_train_batch_size + 2
+
+        ds_config["train_batch_size"] = 1000
+
+        gradient_accumulation_steps = 2
+        ds_config["gradient_accumulation_steps"] = gradient_accumulation_steps + 2
+
+        max_grad_norm = 1.0
+        ds_config["gradient_clipping"] = max_grad_norm + 0.1
+
+        adam_beta1, adam_beta2 = 0.9, 0.99
+        ds_config["optimizer"]["params"]["betas"] = [adam_beta1 - 0.1, adam_beta2 - 0.1]
+
+        fp16 = True
+        ds_config["fp16"]["enabled"] = not fp16
+
+        keys = [
+            "per_device_train_batch_size",
+            "train_batch_size",
+            "gradient_accumulation_steps",
+            "max_grad_norm",
+            "betas",
+            "fp16",
+        ]
+
+        with mockenv_context(**self.dist_env_1_gpu):
+            trainer = get_regression_trainer(
+                local_rank=0,
+                fp16=fp16,
+                deepspeed=ds_config,
+                per_device_train_batch_size=per_device_train_batch_size,
+                gradient_accumulation_steps=gradient_accumulation_steps,
+                max_grad_norm=max_grad_norm,
+                adam_beta1=adam_beta1,
+                adam_beta2=adam_beta2,
+            )
+            with self.assertRaises(Exception) as context:
+                trainer.train()
+
+        for key in keys:
+            self.assertTrue(
+                key in str(context.exception),
+                f"{key} is not in the exception message:\n{context.exception}",
+            )
+
     # Test various combos
     # 1. DS scheduler + DS optimizer: this is already tested by most other tests
     # 2. HF scheduler + HF optimizer:
@@ -219,7 +271,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
             ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
-            trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_zero2_dict)
+            trainer = get_regression_trainer(a=a, local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
             trainer.train()
             new_a = trainer.model.a.item()
             self.assertNotEqual(new_a, a)
@@ -231,7 +283,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             del ds_config_zero2_dict["optimizer"]  # force default HF Trainer optimizer
             ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
-            trainer = get_regression_trainer(a=a, local_rank=0, deepspeed=ds_config_zero2_dict)
+            trainer = get_regression_trainer(a=a, local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
             trainer.train()
             new_a = trainer.model.a.item()
             self.assertNotEqual(new_a, a)
@@ -243,7 +295,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             del ds_config_zero2_dict["scheduler"]  # force default HF Trainer scheduler
             ds_config_zero2_dict["zero_optimization"]["cpu_offload"] = False
             ds_config_zero2_dict["fp16"]["initial_scale_power"] = 1  # force optimizer on the first step
-            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero2_dict)
+            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero2_dict)
             with self.assertRaises(Exception) as context:
                 trainer.train()
             self.assertTrue(
@@ -261,7 +313,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             ds_config_zero3_dict = self.get_config_dict(ZERO3)
             ds_config_zero3_dict["zero_optimization"]["offload_optimizer"] = nvme_config
             ds_config_zero3_dict["zero_optimization"]["offload_param"] = nvme_config
-            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_zero3_dict)
+            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_zero3_dict)
             with CaptureLogger(deepspeed_logger) as cl:
                 trainer.train()
             self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
@@ -279,7 +331,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         elif stage == "stage3":
             ds_config_dict["zero_optimization"]["offload_optimizer"]["device"] = "cpu"
         with mockenv_context(**self.dist_env_1_gpu):
-            trainer = get_regression_trainer(local_rank=0, deepspeed=ds_config_dict)
+            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=ds_config_dict)
             with self.assertRaises(Exception) as context:
                 trainer.train()
             self.assertIn(
@@ -297,7 +349,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         # it's run not as a first test as `sys.stdout` will no longer be the same. So we either have
         # to reset `deepspeed_logger.handlers[0].setStream(sys.stdout)` or directly capture from the deepspeed_logger.
         with mockenv_context(**self.dist_env_1_gpu):
-            trainer = get_regression_trainer(local_rank=0, deepspeed=self.get_config_dict(stage))
+            trainer = get_regression_trainer(local_rank=0, fp16=True, deepspeed=self.get_config_dict(stage))
             with CaptureLogger(deepspeed_logger) as cl:
                 trainer.train()
             self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
@@ -317,6 +369,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
                 b=b,
                 local_rank=0,
                 train_len=8,
+                fp16=True,
                 deepspeed=self.get_config_dict(stage),
                 per_device_train_batch_size=8,
                 logging_steps=1,
@@ -360,6 +413,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
                 b=b,
                 local_rank=0,
                 train_len=train_len,
+                fp16=True,
                 deepspeed=self.get_config_dict(stage),
                 per_device_train_batch_size=8,
                 gradient_accumulation_steps=1,
@@ -377,6 +431,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
                 b=b,
                 local_rank=0,
                 train_len=train_len,
+                fp16=True,
                 deepspeed=self.get_config_dict(stage),
                 per_device_train_batch_size=4,
                 gradient_accumulation_steps=2,
@@ -450,6 +505,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             trainer = get_regression_trainer(
                 output_dir=output_dir,
                 save_steps=freq,
+                fp16=True,
                 deepspeed=ds_config_dict,
             )
             trainer.train()
@@ -463,7 +519,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         with mockenv_context(**self.dist_env_1_gpu):
             ds_config_dict = self.get_config_dict(stage)
             output_dir = self.get_auto_remove_tmp_dir()
-            trainer = get_regression_trainer(output_dir=output_dir, deepspeed=ds_config_dict)
+            trainer = get_regression_trainer(output_dir=output_dir, fp16=True, deepspeed=ds_config_dict)
 
             # 1. fail to find any checkpoint - due a fresh output_dir
             with self.assertRaises(Exception) as context:
if stage == ZERO3: if stage == ZERO3:
ds_config_dict["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = True ds_config_dict["zero_optimization"]["stage3_gather_fp16_weights_on_model_save"] = True
kwargs = dict(output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, deepspeed=ds_config_dict) kwargs = dict(
output_dir=output_dir, train_len=128, save_steps=5, learning_rate=0.1, fp16=True, deepspeed=ds_config_dict
)
with mockenv_context(**self.dist_env_1_gpu): with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs) trainer = get_regression_trainer(**kwargs)
...@@ -528,12 +586,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon): ...@@ -528,12 +586,12 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
# test that we can switch from zero2 to zero3 in the same process for example # test that we can switch from zero2 to zero3 in the same process for example
# test is_zero, etc. # test is_zero, etc.
output_dir = self.get_auto_remove_tmp_dir() output_dir = self.get_auto_remove_tmp_dir()
kwargs = dict(output_dir=output_dir, train_len=8) kwargs = dict(output_dir=output_dir, train_len=8, fp16=True)
with mockenv_context(**self.dist_env_1_gpu): ds_config_zero3_dict = self.get_config_dict("zero3")
ds_config_zero3_dict = self.get_config_dict("zero3") ds_config_zero2_dict = self.get_config_dict("zero2")
ds_config_zero2_dict = self.get_config_dict("zero2")
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs) trainer = get_regression_trainer(deepspeed=ds_config_zero3_dict, **kwargs)
self.assertTrue(is_deepspeed_zero3_enabled()) self.assertTrue(is_deepspeed_zero3_enabled())
......