Commit 45504541 authored by mohammad

Merge branch 'blendable_dataset' into refactor_learning_rate

parents ff12df6b 98989693
@@ -136,14 +136,16 @@ def parse_args(extra_args_provider=None, defaults={},

 def _print_args(args):
     """Print arguments."""
     if args.rank == 0:
-        print('-------------------- arguments --------------------', flush=True)
+        print('------------------------ arguments ------------------------',
+              flush=True)
         str_list = []
         for arg in vars(args):
-            dots = '.' * (32 - len(arg))
+            dots = '.' * (48 - len(arg))
             str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
         for arg in sorted(str_list, key=lambda x: x.lower()):
             print(arg, flush=True)
-        print('---------------- end of arguments ----------------', flush=True)
+        print('-------------------- end of arguments ---------------------',
+              flush=True)


 def _check_arg_is_not_none(args, arg):
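
For reference, a minimal sketch of the widened printout, using a plain argparse.Namespace with made-up fields in place of Megatron's global args:

    from argparse import Namespace

    # Hypothetical args object; real runs use the full Megatron argument set.
    args = Namespace(rank=0, micro_batch_size=8, tensor_model_parallel_size=2)

    for arg in sorted(vars(args), key=str.lower):
        dots = '.' * (48 - len(arg))   # 48 columns of dot padding instead of 32
        print('  {} {} {}'.format(arg, dots, getattr(args, arg)))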
@@ -401,7 +403,10 @@ def _add_data_args(parser):
     group = parser.add_argument_group(title='data and dataloader')

     group.add_argument('--data-path', nargs='*', default=None,
-                       help='Path to combined dataset to split.')
+                       help='Path to the training dataset. Accepted format:'
+                       ' 1) a single data path, 2) multiple datasets in the'
+                       ' form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ...')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
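
The new help text describes an interleaved weight/path format. A minimal sketch of splitting such a specification into normalized weights and dataset prefixes (the helper name and the normalization step are assumptions for illustration, not the actual Megatron-LM code):

    def parse_blended_data_path(data_path):
        # Illustrative helper: split ['w1', 'p1', 'w2', 'p2', ...] into weights
        # and prefixes; a single entry means one dataset with full weight.
        if len(data_path) == 1:
            return [1.0], list(data_path)
        assert len(data_path) % 2 == 0, 'expected weight/path pairs'
        weights = [float(w) for w in data_path[0::2]]
        prefixes = list(data_path[1::2])
        total = sum(weights)
        return [w / total for w in weights], prefixes

    # e.g. --data-path 0.7 /data/wiki_text_document 0.3 /data/books_text_document
    print(parse_blended_data_path(
        ['0.7', '/data/wiki_text_document', '0.3', '/data/books_text_document']))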
@@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


-def save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                    consumed_train_samples=None, consumed_valid_samples=None):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()

@@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler,
         state_dict['args'] = args
         state_dict['checkpoint_version'] = 2.0
         state_dict['iteration'] = iteration
-        if consumed_train_samples:
-            state_dict['consumed_train_samples'] = consumed_train_samples
-        if consumed_valid_samples:
-            state_dict['consumed_valid_samples'] = consumed_valid_samples
         state_dict['model'] = model.state_dict_for_save_checkpoint()

         # Optimizer stuff.
@@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                          checkpoint_name))
             sys.exit()

-    if 'consumed_train_samples' in state_dict:
-        assert args.consumed_train_samples == 0
-        args.consumed_train_samples = state_dict['consumed_train_samples']
-    if 'consumed_valid_samples' in state_dict:
-        assert args.consumed_valid_samples == 0
-        args.consumed_valid_samples = state_dict['consumed_valid_samples']
-
     # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
     if 'args' in state_dict:
         checkpoint_args = state_dict['args']
         check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args, 'consumed_train_samples', 0)
+        args.consumed_valid_samples = getattr(checkpoint_args, 'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
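
With the explicit state_dict keys removed, the consumed-sample counters now travel inside the args namespace that save_checkpoint pickles into the checkpoint. A toy round-trip under that assumption, using SimpleNamespace instead of the real args and skipping torch.save/torch.load:

    import types

    # What save_checkpoint effectively stores: the counters ride along on args.
    args = types.SimpleNamespace(consumed_train_samples=12800, consumed_valid_samples=640)
    state_dict = {'args': args}

    # On resume, the fresh args start at zero and are filled from the saved namespace.
    resumed = types.SimpleNamespace(consumed_train_samples=0, consumed_valid_samples=0)
    ckpt_args = state_dict['args']
    resumed.consumed_train_samples = getattr(ckpt_args, 'consumed_train_samples', 0)
    resumed.consumed_valid_samples = getattr(ckpt_args, 'consumed_valid_samples', 0)
    print(resumed.consumed_train_samples, resumed.consumed_valid_samples)  # 12800 640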
@@ -60,7 +60,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
   for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {

     // Determine where the max error in sampling is happening.
-    double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
+    auto sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
     int64_t max_error_index = 0;
     double max_error = weights_ptr[0] * sample_idx_double -
       static_cast<double>(current_samples[0]);
@@ -86,7 +86,7 @@ void build_blending_indices(py::array_t<uint8_t>& dataset_index,
   if (verbose) {
     std::cout << " > sample ratios:" << std::endl;
     for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
-      double ratio = static_cast<double>(current_samples[dataset_idx]) /
+      auto ratio = static_cast<double>(current_samples[dataset_idx]) /
         static_cast<double>(size);
       std::cout << "   dataset " << dataset_idx << ", input: " <<
         weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
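
For readers who prefer Python, a rough rendering of the greedy loop in build_blending_indices: each global sample is assigned to the dataset whose achieved count lags its target (weight times samples drawn so far) the most. This is an illustration of the algorithm only, not the compiled helper.

    def blend(weights, size):
        num_datasets = len(weights)
        current_samples = [0] * num_datasets
        dataset_index = []         # which dataset each global sample comes from
        dataset_sample_index = []  # running index within that dataset
        for sample_idx in range(size):
            sample_idx_double = max(float(sample_idx), 1.0)
            # Pick the dataset with the largest sampling error so far.
            errors = [weights[i] * sample_idx_double - current_samples[i]
                      for i in range(num_datasets)]
            max_error_index = max(range(num_datasets), key=errors.__getitem__)
            dataset_index.append(max_error_index)
            dataset_sample_index.append(current_samples[max_error_index])
            current_samples[max_error_index] += 1
        return dataset_index, dataset_sample_index

    # e.g. a 70/30 blend over 10 samples
    print(blend([0.7, 0.3], 10))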
@@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    iteration, False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                        consumed_train_samples=args.consumed_train_samples,
-                        consumed_valid_samples=args.consumed_valid_samples)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)

     if args.do_test:
         # Run on test data.
@@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Checkpointing
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                            consumed_train_samples=args.consumed_train_samples,
-                            consumed_valid_samples=args.consumed_valid_samples)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \