"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b23d3a5ad4aa08decd10671f85be5950767dd052"
Unverified commit 460b8443, authored by Sourab Mangrulkar and committed by GitHub

fix trainer slow tests related to hyperparam search (#24011)

* fix trainer slow tests

* commit 2
parent 3c310897
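
This commit factors the Accelerator creation out of Trainer.__init__ into a new create_accelerator_and_postprocess() helper and calls that helper again from _hp_search_setup, so a DeepSpeed-backed hyperparameter search builds a fresh accelerator from the trial's arguments instead of patching a plugin onto the accelerator created before the trial existed (the old line also read self.hf_deepspeed_config, while the config is actually stored on self.args). Below is a minimal sketch of the code path the slow tests exercise; the model name, DeepSpeed config path, and datasets are placeholders, not part of this commit:

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

def model_init(trial):
    # hyperparameter_search requires model_init so each trial starts from a fresh model
    return AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

def hp_space(trial):
    # hypothetical optuna search space; only the learning rate is tuned here
    return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}

args = TrainingArguments(
    output_dir="out",
    deepspeed="ds_config_zero2.json",  # hypothetical DeepSpeed config; a real run needs a distributed launch
    per_device_train_batch_size=8,
    num_train_epochs=1,
)
# train_ds / eval_ds: pre-tokenized datasets, assumed to exist
trainer = Trainer(model_init=model_init, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
best_run = trainer.hyperparameter_search(hp_space=hp_space, backend="optuna", n_trials=4, direction="minimize")
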
@@ -339,31 +339,7 @@ class Trainer:
         self.hp_name = None
         self.is_in_train = False
 
-        # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-        )
-
-        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
-        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
-        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
-        # post accelerator creation setup
-        if self.is_fsdp_enabled:
-            fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
-
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
-
-                ds_plugin = self.accelerator.state.deepspeed_plugin
-
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        self.create_accelerator_and_postprocess()
 
         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)

@@ -1343,7 +1319,8 @@ class Trainer:
             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
-            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+            self.create_accelerator_and_postprocess()
 
     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:

@@ -3924,3 +3901,30 @@ class Trainer:
         if not self.repo.is_repo_clean():
             self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
             self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # create accelerator object
+        self.accelerator = Accelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+
+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
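
The helper is now invoked from the end of __init__ and again from _hp_search_setup when DeepSpeed is configured, so the deepspeed/FSDP flags and plugins always reflect whatever is currently on self.args. A quick sanity check of the flags it sets at construction time, assuming neither DeepSpeed nor FSDP is configured (the placeholder nn.Module stands in for a real model):

import torch
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=torch.nn.Linear(2, 2),  # placeholder module; a real run needs a proper model
    args=TrainingArguments(output_dir="out", report_to="none"),
)
print(trainer.is_deepspeed_enabled, trainer.is_fsdp_enabled)  # expected: False False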