Unverified commit dad414d5 authored by Stas Bekman, committed by GitHub

[trainer + examples] set log level from CLI (#12276)



* set log level from CLI

* add log_level_replica + test + extended docs

* cleanup

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* rename datasets objects to allow datasets module

* improve the doc

* style

* doc improve
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a4ed074d
......@@ -119,6 +119,74 @@ TFTrainingArguments
:members:
Logging
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By default :class:`~transformers.Trainer` will use ``logging.INFO`` for the main process and ``logging.WARNING`` for
the replicas, if any.

These defaults can be overridden to use any of the 5 ``logging`` levels with :class:`~transformers.TrainingArguments`'s
arguments:

- ``log_level`` - for the main process
- ``log_level_replica`` - for the replicas
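
For example, the same choices map to :class:`~transformers.TrainingArguments` fields when configured
programmatically; a minimal sketch (the argument values here are illustrative, not the defaults):

.. code-block:: python

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="output",          # any output directory
        log_level="info",             # main process
        log_level_replica="warning",  # replicas, if any
    )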
Further, if :class:`~transformers.TrainingArguments`'s ``log_on_each_node`` is set to ``False``, only the main node
will use the log level settings of its main process; all other nodes will use the log level settings for replicas.

Note that :class:`~transformers.Trainer` is going to set ``transformers``'s log level separately for each node in its
:meth:`~transformers.Trainer.__init__`, so you may want to set the level earlier (see the next example) if you tap into
other ``transformers`` functionality before creating the :class:`~transformers.Trainer` object.

Here is an example of how this can be used in an application:
.. code-block:: python

    [...]
    logger = logging.getLogger(__name__)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    # set the main code and the modules it uses to the same log level according to the node
    log_level = training_args.get_node_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)

    trainer = Trainer(...)
Then, if you only want to see warnings on the main node and want all other nodes to suppress their (most likely
duplicated) warnings, you could run it as:

.. code-block:: bash

    my_app.py ... --log_level warning --log_level_replica error
In a multi-node environment, if you also don't want the logs to be repeated by each node's main process, change the
above to:

.. code-block:: bash

    my_app.py ... --log_level warning --log_level_replica error --log_on_each_node 0
Then only the main process of the first node will log at the "warning" level, and all other processes on the main
node, as well as all processes on the other nodes, will log at the "error" level.
If you need your application to be as quiet as possible, you could do:

.. code-block:: bash

    my_app.py ... --log_level error --log_level_replica error --log_on_each_node 0

(add ``--log_on_each_node 0`` only when running in a multi-node environment)
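
Under the hood, ``passive`` is stored as ``-1`` and resolved at run time. Here is a minimal sketch of the resolution
logic, mirroring ``get_node_log_level`` from this commit (``should_log`` stands in for "this process should log"):

.. code-block:: python

    import logging

    log_level = -1          # "passive" sentinel for the main process
    log_level_replica = -1  # "passive" sentinel for the replicas
    should_log = True       # True for the process that should log on a given node

    # -1 falls back to the defaults: INFO on the main process, WARNING on replicas
    main_level = logging.INFO if log_level == -1 else log_level
    replica_level = logging.WARNING if log_level_replica == -1 else log_level_replica
    print(main_level if should_log else replica_level)  # 20, i.e. logging.INFO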
Randomness
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -24,6 +24,7 @@ import sys
from dataclasses import dataclass, field
from typing import Optional
import datasets
import numpy as np
from datasets import load_dataset, load_metric
......@@ -243,16 +244,17 @@ def main():
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
log_level = training_args.get_node_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if training_args.should_log:
transformers.utils.logging.set_verbosity_info()
logger.info(f"Training/evaluation parameters {training_args}")
if data_args.source_prefix is None and model_args.model_name_or_path in [
......@@ -296,7 +298,9 @@ def main():
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
raw_datasets = load_dataset(
data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
)
else:
data_files = {}
if data_args.train_file is not None:
......@@ -308,7 +312,7 @@ def main():
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
......@@ -356,11 +360,11 @@ def main():
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
if training_args.do_train:
column_names = datasets["train"].column_names
column_names = raw_datasets["train"].column_names
elif training_args.do_eval:
column_names = datasets["validation"].column_names
column_names = raw_datasets["validation"].column_names
elif training_args.do_predict:
column_names = datasets["test"].column_names
column_names = raw_datasets["test"].column_names
else:
logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
return
......@@ -418,9 +422,9 @@ def main():
return model_inputs
if training_args.do_train:
if "train" not in datasets:
if "train" not in raw_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = datasets["train"]
train_dataset = raw_datasets["train"]
if data_args.max_train_samples is not None:
train_dataset = train_dataset.select(range(data_args.max_train_samples))
train_dataset = train_dataset.map(
......@@ -434,9 +438,9 @@ def main():
if training_args.do_eval:
max_target_length = data_args.val_max_target_length
if "validation" not in datasets:
if "validation" not in raw_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation"]
eval_dataset = raw_datasets["validation"]
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
......@@ -450,9 +454,9 @@ def main():
if training_args.do_predict:
max_target_length = data_args.val_max_target_length
if "test" not in datasets:
if "test" not in raw_datasets:
raise ValueError("--do_predict requires a test dataset")
predict_dataset = datasets["test"]
predict_dataset = raw_datasets["test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
......
......@@ -290,6 +290,10 @@ class Trainer:
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
self._memory_tracker.start()
# set the correct log level depending on the node
log_level = args.get_node_log_level()
logging.set_verbosity(log_level)
# force device and distributed setup init explicitly
args._setup_devices
......
......@@ -905,12 +905,12 @@ def log_metrics(self, split, metrics):
if not self.is_world_process_zero():
return
logger.info(f"***** {split} metrics *****")
print(f"***** {split} metrics *****")
metrics_formatted = self.metrics_format(metrics)
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
print(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
def save_metrics(self, split, metrics, combined=True):
......
......@@ -48,6 +48,8 @@ if is_sagemaker_mp_enabled():
logger = logging.get_logger(__name__)
log_levels = logging.get_log_levels_dict().copy()
trainer_log_levels = dict(**log_levels, passive=-1)
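# ("passive" maps to -1, a sentinel that get_node_log_level() later resolves to
# the defaults: INFO on the main process and WARNING on replicas)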
def default_logdir() -> str:
......@@ -144,6 +146,15 @@ class TrainingArguments:
warmup_steps (:obj:`int`, `optional`, defaults to 0):
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
:obj:`warmup_ratio`.
log_level (:obj:`str`, `optional`, defaults to ``passive``):
Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the
application set the level.
log_level_replica (:obj:`str`, `optional`, defaults to ``passive``):
Logger log level to use on replicas. Same choices as ``log_level``.
log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
In multinode distributed training, whether to log using :obj:`log_level` once per node, or only on the main
node.
logging_dir (:obj:`str`, `optional`):
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
`runs/**CURRENT_DATETIME_HOSTNAME**`.
......@@ -316,8 +327,6 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
In multinode distributed training, whether to log once per node, or only on the main node.
"""
output_dir: str = field(
......@@ -397,6 +406,26 @@ class TrainingArguments:
)
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
log_level: Optional[str] = field(
default="passive",
metadata={
"help": "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug', 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and lets the application set the level. Defaults to 'passive'.",
"choices": trainer_log_levels.keys(),
},
)
log_level_replica: Optional[str] = field(
default="passive",
metadata={
"help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``",
"choices": trainer_log_levels.keys(),
},
)
log_on_each_node: bool = field(
default=True,
metadata={
"help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
},
)
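# A hedged sketch (illustrative, not part of this diff): these dataclass fields
# surface as CLI flags through HfArgumentParser, e.g.
#
#   from transformers import HfArgumentParser, TrainingArguments
#
#   parser = HfArgumentParser(TrainingArguments)
#   (training_args,) = parser.parse_args_into_dataclasses(
#       ["--output_dir", "out", "--log_level", "warning", "--log_level_replica", "error"]
#   )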
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
logging_strategy: IntervalStrategy = field(
default="steps",
......@@ -561,12 +590,6 @@ class TrainingArguments:
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
log_on_each_node: bool = field(
default=True,
metadata={
"help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
},
)
_n_gpu: int = field(init=False, repr=False, default=-1)
mp_parameters: str = field(
default="",
......@@ -580,6 +603,8 @@ class TrainingArguments:
if env_local_rank != -1 and env_local_rank != self.local_rank:
self.local_rank = env_local_rank
self.log_level = trainer_log_levels[self.log_level]
self.log_level_replica = trainer_log_levels[self.log_level_replica]
# expand paths; otherwise os.makedirs("~/bar") will make a directory
# in the current directory instead of the actual home
# see https://github.com/huggingface/transformers/issues/10628
......@@ -889,6 +914,11 @@ class TrainingArguments:
else:
return self.process_index == 0
def get_node_log_level(self):
log_level_main_node = logging.INFO if self.log_level == -1 else self.log_level
log_level_replica_node = logging.WARNING if self.log_level_replica == -1 else self.log_level_replica
return log_level_main_node if self.should_log else log_level_replica_node
@property
def place_model_on_device(self):
"""
......
......@@ -102,6 +102,10 @@ def _reset_library_root_logger() -> None:
_default_handler = None
def get_log_levels_dict():
return log_levels
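# A hedged usage sketch (illustrative, not part of this diff): resolve a CLI
# string such as "warning" to its numeric logging level via the new helper.
#
#   from transformers.utils import logging as hf_logging
#
#   levels = hf_logging.get_log_levels_dict()
#   assert levels["warning"] == 30  # logging.WARNING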
def get_logger(name: Optional[str] = None) -> logging.Logger:
"""
Return a logger with the specified name.
......
......@@ -27,12 +27,20 @@ import numpy as np
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import AutoTokenizer, IntervalStrategy, PretrainedConfig, TrainingArguments, is_torch_available
from transformers import (
AutoTokenizer,
IntervalStrategy,
PretrainedConfig,
TrainingArguments,
is_torch_available,
logging,
)
from transformers.file_utils import WEIGHTS_NAME
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
TestCasePlus,
get_gpu_count,
get_tests_dir,
......@@ -614,6 +622,29 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
self.assertFalse(torch.allclose(trainer.model.b, b))
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
def test_log_level(self):
# testing only --log_level (--log_level_replica requires multiple nodes)
logger = logging.get_logger()
log_info_string = "Running training"
# test with the default log level - should be info and thus log
with CaptureLogger(logger) as cl:
trainer = get_regression_trainer()
trainer.train()
self.assertIn(log_info_string, cl.out)
# test with low log level - lower than info
with CaptureLogger(logger) as cl:
trainer = get_regression_trainer(log_level="debug")
trainer.train()
self.assertIn(log_info_string, cl.out)
# test with high log level - should be quiet
with CaptureLogger(logger) as cl:
trainer = get_regression_trainer(log_level="error")
trainer.train()
self.assertNotIn(log_info_string, cl.out)
def test_model_init(self):
train_dataset = RegressionDataset()
args = TrainingArguments("./regression", learning_rate=0.1)
......