Commit e1b3fc8c authored by zihanl

update training.py

parent d0d83fe1
@@ -53,6 +53,7 @@ from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory


 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()
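For context, here is a minimal, self-contained sketch of the helper this hunk touches. The barrier is the reason the docstring warns about syncing: every rank blocks until all ranks arrive, so the timestamp reflects the slowest rank. The `print_rank_0` helper and the timestamp format are assumptions modeled on Megatron-LM conventions, not a verbatim copy of the file:

```python
# Sketch of a rank-synchronized timestamp logger, modeled on the
# print_datetime context lines above. print_rank_0 and the timestamp
# format are assumptions, not the verbatim upstream code.
from datetime import datetime

import torch


def print_rank_0(message):
    # Assumed helper: print once, from global rank 0 only.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        print(message, flush=True)


def print_datetime(string):
    """Note that this call will sync across all ranks."""
    # barrier() blocks until every rank reaches this point, so the
    # timestamp below reflects the slowest rank; that is the "sync"
    # the docstring warns about.
    torch.distributed.barrier()
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: ' + time_str)
```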
@@ -848,7 +849,7 @@ def build_train_valid_test_data_iterators(
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:

         # Number of train/valid/test samples.
         if args.train_samples:
             train_samples = args.train_samples
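This hunk checks `args.train_samples` first. In Megatron-LM, training length can be given either as a total sample count (`--train-samples`) or as an iteration count (`--train-iters`); when only the latter is set, the sample count is typically derived from it. A sketch of that selection follows; only the `args.train_samples` branch appears in the diff, and the fallback via `train_iters * global_batch_size` is an assumed convention, not part of the lines shown:

```python
# Sketch of the sample-count selection around this hunk. The fallback
# branch is an assumed Megatron-LM convention, not part of the diff.
from argparse import Namespace


def get_train_samples(args):
    if args.train_samples:
        # Total training samples given directly on the command line.
        return args.train_samples
    # Assumed fallback: derive total samples from the iteration count
    # and the global batch size.
    return args.train_iters * args.global_batch_size


# Example: 1000 iterations at a global batch size of 512.
print(get_train_samples(Namespace(train_samples=None,
                                  train_iters=1000,
                                  global_batch_size=512)))  # 512000
```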