chenpangpang / transformers, commit 460b8443 (unverified)
Authored Jun 05, 2023 by Sourab Mangrulkar; committed by GitHub on Jun 05, 2023

fix trainer slow tests related to hyperparam search (#24011)

* fix trainer slow tests
* commit 2

parent 3c310897
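For context, the commit targets the hyperparameter-search path of Trainer, where the model (and, after this change, the accelerator/DeepSpeed state) is rebuilt for every trial. The sketch below is illustrative only and is not part of this commit; the checkpoint name, search space, and output directory are assumptions, and it assumes optuna is installed.

# Illustrative sketch of the workflow exercised by the slow tests this commit
# fixes (hyperparameter search with Trainer); names and values are assumptions.
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

def model_init(trial=None):
    # A fresh model is built for every trial; the checkpoint is an assumption.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

def optuna_hp_space(trial):
    # Hypothetical search space; keys must match TrainingArguments fields.
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
    }

def run_search(train_dataset, eval_dataset):
    # Datasets are supplied by the caller; building them is out of scope here.
    args = TrainingArguments(output_dir="hp_search_out", evaluation_strategy="epoch")
    trainer = Trainer(model_init=model_init, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset)
    # Each trial re-enters Trainer's hyperparameter-search setup, which is where
    # the accelerator/DeepSpeed state now gets rebuilt (see the diff below).
    return trainer.hyperparameter_search(direction="minimize", backend="optuna", hp_space=optuna_hp_space, n_trials=4)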
Showing 1 changed file with 30 additions and 26 deletions

src/transformers/trainer.py (+30, -26)
@@ -339,31 +339,7 @@ class Trainer:
         self.hp_name = None
         self.is_in_train = False

-        # create accelerator object
-        self.accelerator = Accelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
-        )
-
-        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
-        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
-        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
-
-        # post accelerator creation setup
-        if self.is_fsdp_enabled:
-            fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
-
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
-
-                ds_plugin = self.accelerator.state.deepspeed_plugin
-
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        self.create_accelerator_and_postprocess()

         # memory metrics - must set up as early as possible
         self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
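Both the block removed above and the new helper at the end of the diff rely on accelerate's Accelerator and its shared state to detect whether DeepSpeed or FSDP is configured. A minimal standalone sketch of that detection, assuming accelerate is installed and no DeepSpeed/FSDP launcher config is active:

# Minimal sketch (assumption: plain CPU/GPU run, no DeepSpeed/FSDP config active).
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=1)

# The plugins live on the shared accelerator state; a DeepSpeed/FSDP launcher
# or config populates them, so getattr(..., None) doubles as a feature check.
is_deepspeed_enabled = getattr(accelerator.state, "deepspeed_plugin", None) is not None
is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None
print(is_deepspeed_enabled, is_fsdp_enabled)  # expected: False False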
@@ -1343,7 +1319,8 @@ class Trainer:
             self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
             self.args.hf_deepspeed_config.trainer_config_process(self.args)
-            self.accelerator.state.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+
+        self.create_accelerator_and_postprocess()

     def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
         if self.hp_search_backend is None or trial is None:
@@ -3924,3 +3901,30 @@ class Trainer:
             if not self.repo.is_repo_clean():
                 self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
                 self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # create accelerator object
+        self.accelerator = Accelerator(
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
+            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+
+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
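The net effect of the diff is a common refactor: state that must be fresh for every run is built by a dedicated method, so the hyperparameter-search path can rebuild it per trial instead of relying on what __init__ did once. A generic sketch of that pattern, not transformers code; the class and method names below are illustrative:

# Generic illustration of the pattern applied by this commit (hypothetical class).
class Runner:
    def __init__(self, config: dict):
        self.config = dict(config)
        self._build_run_state()  # same helper used by __init__ and the search loop

    def _build_run_state(self):
        # Everything that must be rebuilt for each trial lives here,
        # mirroring Trainer.create_accelerator_and_postprocess() above.
        self.state = {"lr": self.config.get("lr", 1e-3)}

    def run_trial(self, overrides: dict) -> dict:
        # A trial mutates the config, then rebuilds the per-run state,
        # mirroring the hyperparameter-search setup in the diff calling
        # create_accelerator_and_postprocess() after swapping the DeepSpeed plugin.
        self.config.update(overrides)
        self._build_run_state()
        return self.state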