Unverified Commit f086652b authored by Sylvain Gugger, committed by GitHub

Add option to log only once in multinode training (#11819)

* Add option to log only once in multinode training

* Use an alternate property
parent b8344a27
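
Usage sketch (not part of the diff below; the output directory is a placeholder): the commit adds a `log_on_each_node` training argument plus a `should_log` property, which the example scripts now use instead of `is_main_process(training_args.local_rank)`.

```python
# Illustrative sketch of the new option; "out" is a placeholder output directory.
import logging
import sys

import transformers
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    log_on_each_node=False,  # new flag added by this commit: log only on the main node
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# `should_log` replaces the old `is_main_process(training_args.local_rank)` check.
logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
if training_args.should_log:
    transformers.utils.logging.set_verbosity_info()
```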
@@ -44,7 +44,7 @@ from transformers import (
     set_seed,
 )
 from transformers.testing_utils import CaptureLogger
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -202,7 +202,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -210,7 +210,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -43,7 +43,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -211,7 +211,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -219,7 +219,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -39,7 +39,7 @@ from transformers import (
     XLNetLMHeadModel,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -208,7 +208,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -216,7 +216,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -41,7 +41,7 @@ from transformers import (
 )
 from transformers.file_utils import PaddingStrategy
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -235,7 +235,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -243,7 +243,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions
@@ -228,7 +228,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -236,7 +236,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -39,7 +39,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from utils_qa import postprocess_qa_predictions_with_beam_search
@@ -227,7 +227,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -235,7 +235,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -41,7 +41,7 @@ from transformers import (
     set_seed,
 )
 from transformers.file_utils import is_offline_mode
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -284,7 +284,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -292,7 +292,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")
...
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -216,7 +216,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -224,7 +224,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -40,7 +40,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -186,7 +186,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -195,7 +195,7 @@ def main():
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -40,7 +40,7 @@ from transformers import (
     TrainingArguments,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -201,7 +201,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -209,7 +209,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()
...
@@ -44,7 +44,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -268,7 +268,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -276,7 +276,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")
...
@@ -1781,21 +1781,16 @@ class Trainer:
         Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
         machines) main process.
         """
-        if is_torch_tpu_available():
-            return xm.is_master_ordinal(local=True)
-        elif is_sagemaker_mp_enabled():
-            return smp.local_rank() == 0
-        else:
-            return self.args.local_rank in [-1, 0]
+        return self.args.local_process_index == 0

     def is_world_process_zero(self) -> bool:
         """
         Whether or not this process is the global main process (when training in a distributed fashion on several
         machines, this is only going to be :obj:`True` for one process).
         """
-        if is_torch_tpu_available():
-            return xm.is_master_ordinal(local=False)
-        elif is_sagemaker_mp_enabled():
+        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+        # process index.
+        if is_sagemaker_mp_enabled():
             return smp.rank() == 0
         else:
             return self.args.process_index == 0
...
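A hedged illustration of how the simplified `Trainer` helpers above are typically used from user code; the `finish_run` function below is hypothetical and assumes an already-constructed `Trainer`.

```python
from transformers import Trainer


def finish_run(trainer: Trainer, output_dir: str) -> None:
    """Hypothetical helper showing the two process-zero checks."""
    if trainer.is_local_process_zero():
        # True for one process per machine: now simply args.local_process_index == 0.
        print("node-local cleanup")
    if trainer.is_world_process_zero():
        # True for exactly one process in the whole job (SageMaker MP keeps its own rank check).
        trainer.save_model(output_dir)
```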
@@ -316,6 +316,8 @@ class TrainingArguments:
             :class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
             the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
             details.
+        log_on_each_node (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            In multinode distributed training, whether to log once per node, or only on the main node.
     """

     output_dir: str = field(
@@ -559,6 +561,12 @@
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    log_on_each_node: bool = field(
+        default=True,
+        metadata={
+            "help": "When doing a multinode distributed training, whether to log once per node or just once on the main node."
+        },
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
@@ -834,7 +842,7 @@
     @torch_required
     def process_index(self):
         """
-        The number of processes used in parallel.
+        The index of the current process used.
         """
         if is_torch_tpu_available():
             return xm.get_ordinal()
@@ -846,6 +854,35 @@
             return torch.distributed.get_rank()
         return 0

+    @property
+    @torch_required
+    def local_process_index(self):
+        """
+        The index of the local process used.
+        """
+        if is_torch_tpu_available():
+            return xm.get_ordinal(local=True)
+        elif is_sagemaker_mp_enabled():
+            return smp.local_rank()
+        elif is_sagemaker_dp_enabled():
+            return sm_dist.get_rank()
+        elif self.local_rank != -1:
+            return self.local_rank
+        return 0
+
+    @property
+    def should_log(self):
+        """
+        Whether or not the current process should produce log.
+        """
+        if self.log_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
     @property
     def place_model_on_device(self):
         """
...
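For clarity, a stand-alone sketch that mirrors the decision made by the new `should_log` property (the SageMaker ModelParallel special case is omitted; the real logic lives on `TrainingArguments` as shown above).

```python
# Stand-alone mirror of the new `should_log` logic, for illustration only.
def should_log(log_on_each_node: bool, local_process_index: int, process_index: int) -> bool:
    if log_on_each_node:
        # One logging process per node.
        return local_process_index == 0
    # A single logging process for the whole job.
    return process_index == 0


# Two nodes with 4 GPUs each: global rank 4 is local rank 0 on the second node.
assert should_log(True, local_process_index=0, process_index=4)
assert not should_log(False, local_process_index=0, process_index=4)
```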
@@ -43,7 +43,7 @@ from transformers import (
     default_data_collator,
     set_seed,
 )
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint

 logger = logging.getLogger(__name__)
@@ -226,7 +226,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -234,7 +234,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
     logger.info(f"Training/evaluation parameters {training_args}")
...
@@ -42,7 +42,7 @@ from transformers import (  # Trainer,; TrainingArguments,
 # Will import SageMaker Model parallelism specific Trainer
 from transformers.sagemaker import SageMakerTrainer as Trainer
 from transformers.sagemaker import SageMakerTrainingArguments as TrainingArguments
-from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
@@ -210,7 +210,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
     # Log on each process the small summary:
     logger.warning(
@@ -218,7 +218,7 @@ def main():
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if training_args.should_log:
         transformers.utils.logging.set_verbosity_info()
         transformers.utils.logging.enable_default_handler()
         transformers.utils.logging.enable_explicit_format()