Unverified Commit 17e1a1f6 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

[transformer] use logger in microbatches module (#1302)

parent 4506a687
...@@ -17,6 +17,11 @@ from abc import ABC ...@@ -17,6 +17,11 @@ from abc import ABC
from abc import abstractmethod from abc import abstractmethod
from typing import Optional, List from typing import Optional, List
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
def build_num_microbatches_calculator( def build_num_microbatches_calculator(
rank: int, rank: int,
...@@ -31,8 +36,8 @@ def build_num_microbatches_calculator( ...@@ -31,8 +36,8 @@ def build_num_microbatches_calculator(
global_batch_size, micro_batch_size, data_parallel_size global_batch_size, micro_batch_size, data_parallel_size
) )
if rank == 0: if rank == 0:
print( _logger.info(
"setting number of micro-batches to constant {}".format(num_microbatches_calculator.get()), flush=True "setting number of micro-batches to constant {}".format(num_microbatches_calculator.get())
) )
else: else:
...@@ -45,7 +50,7 @@ def build_num_microbatches_calculator( ...@@ -45,7 +50,7 @@ def build_num_microbatches_calculator(
batch_size_increment = int(rampup_batch_size[1]) batch_size_increment = int(rampup_batch_size[1])
ramup_samples = int(rampup_batch_size[2]) ramup_samples = int(rampup_batch_size[2])
if rank == 0: if rank == 0:
print( _logger.info(
"will use batch size rampup starting from global batch " "will use batch size rampup starting from global batch "
"size {} to global batch size {} with batch size increments " "size {} to global batch size {} with batch size increments "
"{} over {} samples.".format( "{} over {} samples.".format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment