Unverified commit ce2fef2a authored by Stas Bekman, committed by GitHub

[trainer / deepspeed] fix hyperparameter_search (#16740)

* [trainer / deepspeed] fix hyperparameter_search

* require optuna

* style

* oops

* add dep in the right place

* create deepspeed-testing dep group

* Trigger CI
parent 1b7de41a
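
The gist of the fix: when hyperparameter_search runs under DeepSpeed, every optuna trial mutates self.args, so the DeepSpeed config has to be rebuilt from the updated arguments. The old code rebuilt it with the base HfDeepSpeedConfig, which only stores the config; the fix switches to HfTrainerDeepSpeedConfig and calls trainer_config_process(self.args) so that each trial's hyperparameters actually reach the engine config (see the trainer.py hunk below, with a short sketch after it). A new deepspeed-testing extras group (deepspeed + testing + optuna) is also defined so that the CI workflows and the nightly Docker image can install everything the accompanying regression test needs in one shot.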
@@ -157,7 +157,7 @@ jobs:
         apt -y update && apt install -y libaio-dev libsndfile1-dev git espeak-ng
         pip install --upgrade pip
         pip install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html -U
-        pip install .[testing,deepspeed]
+        pip install .[deepspeed-testing]
         pip install https://github.com/kpu/kenlm/archive/master.zip
         pip install git+https://github.com/microsoft/DeepSpeed
@@ -384,7 +384,7 @@ jobs:
       run: |
         apt -y update && apt install -y libaio-dev
         pip install --upgrade pip
-        pip install .[testing,deepspeed]
+        pip install .[deepspeed-testing]
     - name: Are GPUs recognized by our DL frameworks
       run: |
@@ -9,7 +9,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip
 ARG REF=main
 RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF
-RUN python3 -m pip install --no-cache-dir -e ./transformers[testing,deepspeed]
+RUN python3 -m pip install --no-cache-dir -e ./transformers[deepspeed-testing]
 RUN git clone https://github.com/microsoft/DeepSpeed && cd DeepSpeed && rm -rf build && \
     DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 python3 -m pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1
@@ -290,6 +290,8 @@ extras["testing"] = (
     + extras["modelcreation"]
 )
 
+extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"]
+
 extras["quality"] = deps_list("black", "isort", "flake8", "GitPython", "hf-doc-builder")
 
 extras["all"] = (
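This new group is what the workflow and Dockerfile hunks above switch to: pip install .[deepspeed-testing] resolves to the union of the deepspeed, testing, and optuna extras. In particular it pulls in optuna, which the new test requires and which the old [testing,deepspeed] spelling never installed.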
@@ -976,9 +976,10 @@ class Trainer:
             logger.info(f"W&B Sweep parameters: {trial}")
         if self.args.deepspeed:
             # Rebuild the deepspeed config to reflect the updated training parameters
-            from transformers.deepspeed import HfDeepSpeedConfig
+            from transformers.deepspeed import HfTrainerDeepSpeedConfig
 
-            self.args.hf_deepspeed_config = HfDeepSpeedConfig(self.args.deepspeed)
+            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
+            self.args.hf_deepspeed_config.trainer_config_process(self.args)
 
     def _report_to_hp_search(
         self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
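Why the class swap matters: the base HfDeepSpeedConfig only holds the config dict (and registers it as the globally active one), while the HfTrainerDeepSpeedConfig subclass adds trainer_config_process(), which backfills the config's "auto" entries from a TrainingArguments instance. A minimal sketch of the new code path, assuming a ZeRO config with "auto" placeholders (the config values and output_dir below are illustrative, not from this diff):

    from transformers import TrainingArguments
    from transformers.deepspeed import HfTrainerDeepSpeedConfig

    # a DeepSpeed config that defers the batch size to the trainer
    ds_config = {
        "zero_optimization": {"stage": 3},
        "train_micro_batch_size_per_gpu": "auto",
    }
    # during hyperparameter_search, each trial rewrites fields of self.args
    args = TrainingArguments(output_dir="/tmp/ds_test", per_device_train_batch_size=16)

    hf_ds_cfg = HfTrainerDeepSpeedConfig(ds_config)
    hf_ds_cfg.trainer_config_process(args)  # the "auto" entry now reflects args

With the base class, trainer_config_process() simply does not exist, so the rebuilt config never saw the trial's updated argument values.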
@@ -34,6 +34,7 @@ from transformers.testing_utils import (
     get_gpu_count,
     mockenv_context,
     require_deepspeed,
+    require_optuna,
     require_torch_gpu,
     require_torch_multi_gpu,
     slow,
@@ -363,6 +364,33 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
                 trainer.train()
             self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
 
+    @require_optuna
+    def test_hyperparameter_search(self):
+        with mockenv_context(**self.dist_env_1_gpu):
+
+            ds_config_zero3_dict = self.get_config_dict(ZERO3)
+
+            # hyperparameter_search requires model_init() to recreate the model for each trial
+            def model_init():
+                config = RegressionModelConfig(a=0, b=0, double_output=False)
+                model = RegressionPreTrainedModel(config)
+                return model
+
+            trainer = get_regression_trainer(
+                local_rank=0,
+                fp16=True,
+                model_init=model_init,
+                deepspeed=ds_config_zero3_dict,
+            )
+
+            n_trials = 3
+            with CaptureLogger(deepspeed_logger) as cl:
+                with CaptureStd() as cs:
+                    trainer.hyperparameter_search(direction="maximize", n_trials=n_trials)
+            self.assertIn("DeepSpeed info", cl.out, "expected DeepSpeed logger output but got none")
+            self.assertIn(f"Trial {n_trials-1} finished with value", cs.err, "expected hyperparameter_search output")
+            self.assertIn("Best is trial", cs.err, "expected hyperparameter_search output")
+
     # --- These tests need to run on both zero stages --- #
     @parameterized.expand(params, name_func=parameterized_custom_name_func)
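A note on what the assertions capture: the "DeepSpeed info" line goes through DeepSpeed's own logger, hence CaptureLogger(deepspeed_logger), while optuna reports per-trial results ("Trial 2 finished with value ...", "Best is trial ...") through its logger, which writes to stderr by default, hence the checks on cs.err. Since the test passes no hp_space, the optuna backend falls back to Trainer's default search space. For comparison, a hypothetical, more explicit variant of the same call (the hp_space body is illustrative):

    def hp_space(trial):
        # suggest only the learning rate instead of the default search space
        return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

    best_run = trainer.hyperparameter_search(
        direction="maximize",
        backend="optuna",
        hp_space=hp_space,
        n_trials=3,
    )
    print(best_run.hyperparameters)  # parameters of the winning trial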