from types import SimpleNamespace from megatron.training.global_vars import set_args from megatron.training.training import build_train_valid_test_data_iterators from tests.unit_tests.test_utilities import Utils def mock_train_valid_test_datasets_provider(train_val_test_num_samples): return 1, 2, 3 def create_test_args(): # Set dummy values for the args. args = SimpleNamespace() args.iteration = 0 args.train_samples = 1 args.train_iters = 1 args.eval_interval = 1 args.eval_iters = 1 args.global_batch_size = 1 args.consumed_train_samples = 1 args.consumed_valid_samples = 1 args.dataloader_type = "external" args.skip_train = False return args class TestTraining: def setup_method(self, method): Utils.initialize_model_parallel(1, 1) args = create_test_args() set_args(args) def test_build_train_valid_test_data_iterators(self): train_iter, valid_iter, test_iter = build_train_valid_test_data_iterators( mock_train_valid_test_datasets_provider ) assert (train_iter, valid_iter, test_iter) == (1, 2, 3) def teardown_method(self, method): Utils.destroy_model_parallel()