chenpangpang / transformers · Commits

Commit 82d46feb (unverified)
Authored Jan 22, 2021 by Sylvain Gugger, committed via GitHub on Jan 22, 2021

Add `report_to` training arguments to control the reporting integrations used (#9735)

Parent: 411c5821
Showing 3 changed files with 47 additions and 32 deletions:

src/transformers/integrations.py    +33 -0
src/transformers/trainer.py          +3 -32
src/transformers/training_args.py   +11 -0
src/transformers/integrations.py  (view file @ 82d46feb)
@@ -225,6 +225,21 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
     return best_run
 
 
+def get_available_reporting_integrations():
+    integrations = []
+    if is_azureml_available():
+        integrations.append("azure_ml")
+    if is_comet_available():
+        integrations.append("comet_ml")
+    if is_mlflow_available():
+        integrations.append("mlflow")
+    if is_tensorboard_available():
+        integrations.append("tensorboard")
+    if is_wandb_available():
+        integrations.append("wandb")
+    return integrations
+
+
 def rewrite_logs(d):
     new_d = {}
     eval_prefix = "eval_"
@@ -757,3 +772,21 @@ class MLflowCallback(TrainerCallback):
         # not let you start a new run before the previous one is killed
         if self._ml_flow.active_run is not None:
             self._ml_flow.end_run(status="KILLED")
+
+
+INTEGRATION_TO_CALLBACK = {
+    "azure_ml": AzureMLCallback,
+    "comet_ml": CometCallback,
+    "mlflow": MLflowCallback,
+    "tensorboard": TensorBoardCallback,
+    "wandb": WandbCallback,
+}
+
+
+def get_reporting_integration_callbacks(report_to):
+    for integration in report_to:
+        if integration not in INTEGRATION_TO_CALLBACK:
+            raise ValueError(
+                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
+            )
+    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
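Taken together, the two new helpers split discovery from dispatch: `get_available_reporting_integrations` probes which client libraries are importable, while `get_reporting_integration_callbacks` validates names and maps them to callback classes. A minimal usage sketch (the return values shown are inferred from the code above, not captured output):

    from transformers.integrations import (
        get_available_reporting_integrations,
        get_reporting_integration_callbacks,
    )

    # Subset of {"azure_ml", "comet_ml", "mlflow", "tensorboard", "wandb"}
    # whose client library is installed in the current environment.
    print(get_available_reporting_integrations())  # e.g. ["tensorboard", "wandb"]

    # Names map to callback classes (not instances); unknown names raise ValueError.
    callbacks = get_reporting_integration_callbacks(["tensorboard"])
    # callbacks == [TensorBoardCallback]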
src/transformers/trainer.py  (view file @ 82d46feb)
@@ -31,15 +31,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 # Integrations must be imported before ML frameworks:
 from .integrations import (  # isort: split
     default_hp_search_backend,
+    get_reporting_integration_callbacks,
     hp_params,
-    is_azureml_available,
-    is_comet_available,
     is_fairscale_available,
-    is_mlflow_available,
     is_optuna_available,
     is_ray_tune_available,
-    is_tensorboard_available,
-    is_wandb_available,
     run_hp_search_optuna,
     run_hp_search_ray,
     init_deepspeed,
@@ -124,32 +120,6 @@ if is_torch_tpu_available():
     import torch_xla.debug.metrics as met
     import torch_xla.distributed.parallel_loader as pl
 
-if is_tensorboard_available():
-    from .integrations import TensorBoardCallback
-
-    DEFAULT_CALLBACKS.append(TensorBoardCallback)
-
-if is_wandb_available():
-    from .integrations import WandbCallback
-
-    DEFAULT_CALLBACKS.append(WandbCallback)
-
-if is_comet_available():
-    from .integrations import CometCallback
-
-    DEFAULT_CALLBACKS.append(CometCallback)
-
-if is_mlflow_available():
-    from .integrations import MLflowCallback
-
-    DEFAULT_CALLBACKS.append(MLflowCallback)
-
-if is_azureml_available():
-    from .integrations import AzureMLCallback
-
-    DEFAULT_CALLBACKS.append(AzureMLCallback)
-
 if is_fairscale_available():
     from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
     from fairscale.optim import OSS
@@ -300,7 +270,8 @@ class Trainer:
                 "Passing a `model_init` is incompatible with providing the `optimizers` argument."
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
-        callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
+        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
             callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
         )
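The constructor change means reporting callbacks are now resolved per `Trainer` instance from `args.report_to`, instead of being appended to `DEFAULT_CALLBACKS` once at import time. A hedged sketch of the effect (model and dataset setup elided; the `report_to` value is an arbitrary example):

    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(output_dir="out", report_to=["tensorboard"])
    trainer = Trainer(model=model, args=args)  # `model` defined elsewhere
    # The callback handler now holds DEFAULT_CALLBACKS + [TensorBoardCallback];
    # other installed platforms (wandb, comet_ml, ...) are no longer
    # auto-registered just because their libraries are importable.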
src/transformers/training_args.py  (view file @ 82d46feb)
@@ -231,6 +231,9 @@ class TrainingArguments:
         group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether or not to group together samples of roughly the same length in the training dataset (to minimize
             padding applied and be more efficient). Only useful if applying dynamic padding.
+        report_to (:obj:`List[str]`, `optional`, defaults to the list of integration platforms installed):
+            The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
+            :obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
     """
 
     output_dir: str = field(
@@ -413,6 +416,9 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
     )
+    report_to: Optional[List[str]] = field(
+        default=None, metadata={"help": "The list of integrations to report the results and logs to."}
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
 
     def __post_init__(self):
@@ -434,6 +440,11 @@ class TrainingArguments:
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+        if self.report_to is None:
+            # Import at runtime to avoid a circular import.
+            from .integrations import get_available_reporting_integrations
+
+            self.report_to = get_available_reporting_integrations()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
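From the user side, the new field makes reporting opt-in per run. A short sketch of the three cases the `__post_init__` logic above distinguishes (behavior inferred from the diff: only a `None` value is replaced by the installed-platform list):

    from transformers import TrainingArguments

    # None (the default): __post_init__ fills in every installed platform.
    args = TrainingArguments(output_dir="out")

    # Explicit list: report only to the named platforms.
    args = TrainingArguments(output_dir="out", report_to=["wandb"])

    # Empty list: not None, so __post_init__ leaves it alone and no
    # reporting integration callbacks are added.
    args = TrainingArguments(output_dir="out", report_to=[])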