Unverified commit d28b7aa8, authored by Stas Bekman, committed by GitHub

[deepspeed / testing] reset global state (#17553)

* [deepspeed] fix load_best_model test

* [deepspeed] add state reset on unittest tearDown
parent 34a886fc
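The failure mode this commit addresses: transformers keeps the active DeepSpeed config in a module-level global, so a test that sets it can silently affect every test that runs after it in the same process. Below is a minimal standalone sketch of the leak and of the `tearDown` fix; the stub names are illustrative stand-ins, not the transformers implementation.

import unittest
import weakref


class DsConfigStub:
    """Hypothetical stand-in for HfDeepSpeedConfig."""

    def __init__(self, stage):
        self.stage = stage


_config_weak_ref = None  # module-level global, survives across tests


def set_config(cfg):
    global _config_weak_ref
    _config_weak_ref = weakref.ref(cfg)


def unset_config():
    # reset the global so one test's config can't leak into the next
    global _config_weak_ref
    _config_weak_ref = None


def is_zero3_enabled():
    return _config_weak_ref is not None and _config_weak_ref() is not None and _config_weak_ref().stage == 3


class LeakDemoTest(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        unset_config()  # the pattern this commit adds to the deepspeed tests

    def test_a_sets_global(self):
        # stash the config on the class so a strong reference outlives this
        # test, the way a TrainingArguments object held by a fixture would
        type(self).cfg = DsConfigStub(stage=3)
        set_config(type(self).cfg)
        self.assertTrue(is_zero3_enabled())

    def test_b_expects_clean_state(self):
        # passes only because tearDown reset the global after test_a
        self.assertFalse(is_zero3_enabled())


if __name__ == "__main__":
    unittest.main()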
src/transformers/deepspeed.py

@@ -295,6 +295,12 @@ def set_hf_deepspeed_config(hf_deepspeed_config_obj):
     _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
 
 
+def unset_hf_deepspeed_config():
+    # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method
+    global _hf_deepspeed_config_weak_ref
+    _hf_deepspeed_config_weak_ref = None
+
+
 def is_deepspeed_zero3_enabled():
     if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
         return _hf_deepspeed_config_weak_ref().is_zero3()
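Note the double check in the accessor: the global holds a weak reference, so it must guard both against never having been set and against the referent having been garbage-collected. A short sketch of that behavior (the stub class is an assumption, not the real HfDeepSpeedConfig):

import weakref


class CfgStub:  # hypothetical stand-in for HfDeepSpeedConfig
    def is_zero3(self):
        return True


cfg = CfgStub()
_ref = weakref.ref(cfg)
assert _ref() is cfg  # live while a strong reference exists

del cfg  # drop the only strong reference (CPython collects immediately)
assert _ref is not None  # the global still holds a weakref object...
assert _ref() is None    # ...but it is dead, hence the two-part check in
                         # is_deepspeed_zero3_enabled()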
tests/deepspeed/test_deepspeed.py

@@ -25,7 +25,7 @@ import datasets
 from parameterized import parameterized
 from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import AutoModel, TrainingArguments, is_torch_available, logging
-from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available
+from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_available, unset_hf_deepspeed_config
 from transformers.testing_utils import (
     CaptureLogger,
     CaptureStd,
@@ -161,6 +161,12 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             MASTER_ADDR="localhost", MASTER_PORT=master_port, RANK="0", LOCAL_RANK="0", WORLD_SIZE="1"
         )
 
+    def tearDown(self):
+        super().tearDown()
+
+        # reset the ds config global so that tests state doesn't leak
+        unset_hf_deepspeed_config()
+
     def test_init_zero3_fp16(self):
         # test that zero.Init() works correctly under zero3/fp16
         ds_config = {
@@ -229,6 +235,12 @@ class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
             zero3=config_zero3,
         )
 
+    def tearDown(self):
+        super().tearDown()
+
+        # reset the ds config global so that tests state doesn't leak
+        unset_hf_deepspeed_config()
+
     def get_config_dict(self, stage):
         # As some tests modify the dict, always make a copy
         return deepcopy(self.ds_config_dict[stage])
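The same two-line `tearDown` is added to each test class that can set the global. If more classes needed it, a shared mixin could factor out the duplication; a hypothetical sketch (`DeepSpeedStateResetMixin` is not part of this commit, and the import path is the one used as of this commit):

import unittest

from transformers.deepspeed import unset_hf_deepspeed_config


class DeepSpeedStateResetMixin(unittest.TestCase):
    # hypothetical helper: inherit from this instead of repeating tearDown
    def tearDown(self):
        super().tearDown()
        # reset the ds config global so that tests state doesn't leak
        unset_hf_deepspeed_config()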
@@ -754,6 +766,25 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
         with mockenv_context(**self.dist_env_1_gpu):
+            args_dict = {
+                "per_gpu_train_batch_size": 1,
+                "per_gpu_eval_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "learning_rate": 1e-4,
+                "num_train_epochs": 1,
+                "do_train": True,
+                "do_eval": True,
+                "optim": "adafactor",
+                "evaluation_strategy": "steps",
+                "eval_steps": 1,
+                "save_strategy": "steps",
+                "save_steps": 1,
+                "load_best_model_at_end": True,
+                "max_steps": 1,
+                "deepspeed": ds_config_dict,
+            }
+
+            training_args = TrainingArguments(output_dir, **args_dict)
             tokenizer = T5Tokenizer.from_pretrained(T5_TINY)
             model = T5ForConditionalGeneration.from_pretrained(T5_TINY)
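Why this hunk builds `args_dict` and `TrainingArguments` before loading the model: passing a `deepspeed` config to `TrainingArguments` registers the global deepspeed config (via `HfDeepSpeedConfig`), and `from_pretrained()` consults `is_deepspeed_zero3_enabled()` to choose the zero3 loading path, so the global must be set first. A minimal sketch, assuming deepspeed is installed; the stripped-down config and the `"t5-small"` checkpoint are illustrative stand-ins for the test's `ds_config_dict` and `T5_TINY`:

from transformers import T5ForConditionalGeneration, TrainingArguments
from transformers.deepspeed import is_deepspeed_zero3_enabled

ds_config = {
    "train_micro_batch_size_per_gpu": "auto",  # filled in from the args
    "zero_optimization": {"stage": 3},
}

# constructing the args registers the deepspeed config in the global
training_args = TrainingArguments(output_dir="/tmp/out", deepspeed=ds_config)
assert is_deepspeed_zero3_enabled()

# only now does from_pretrained() take the zero3-aware loading path
model = T5ForConditionalGeneration.from_pretrained("t5-small")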
@@ -788,26 +819,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
             train_dataset, eval_dataset = get_dataset()
 
-            args_dict = {
-                "per_gpu_train_batch_size": 1,
-                "per_gpu_eval_batch_size": 1,
-                "gradient_accumulation_steps": 1,
-                "learning_rate": 1e-4,
-                "num_train_epochs": 1,
-                "do_train": True,
-                "do_eval": True,
-                "optim": "adafactor",
-                "evaluation_strategy": "steps",
-                "eval_steps": 1,
-                "save_strategy": "steps",
-                "save_steps": 1,
-                "load_best_model_at_end": True,
-                "max_steps": 1,
-                "deepspeed": ds_config_dict,
-            }
-
-            training_args = TrainingArguments(output_dir, **args_dict)
-
             trainer = Trainer(
                 model=model,
                 tokenizer=tokenizer,