"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "651150cb285f33541e07372819f559c9f5083cc1"
Commit 599e959a authored by Mohammad's avatar Mohammad
Browse files

working on bert

parent 1446bb64
@@ -223,6 +223,35 @@ def add_validation_args(parser):
                        help='Interval between running evaluation on '
                        'validation set.')
     return parser
+
+
+def add_data_args(parser):
+    group = parser.add_argument_group(title='data and dataloader')
+
+    group.add_argument('--data-path', type=str, required=True,
+                       help='Path to combined dataset to split.')
+    group.add_argument('--split', type=str, required=True,
+                       help='Comma-separated list of proportions for training,'
+                       ' validation, and test split. For example the split '
+                       '`90,5,5` will use 90% of data for training, 5% for '
+                       'validation and 5% for test.')
+    group.add_argument('--vocab-file', type=str, required=True,
+                       help='Path to the vocab file.')
+    group.add_argument('--seq-length', type=int, required=True,
+                       help="Maximum sequence length to process.")
+    group.add_argument('--mask-prob', type=float, default=0.15,
+                       help='Probability of replacing a token with mask.')
+    group.add_argument('--short-seq-prob', type=float, default=0.1,
+                       help='Probability of producing a short sequence.')
+    group.add_argument('--mmap-warmup', action='store_true',
+                       help='Warm up mmap files.')
+    group.add_argument('--num-workers', type=int, default=2,
+                       help="Dataloader number of workers.")
+
+    return parser
 
 
 ########################
@@ -290,12 +319,6 @@ def add_training_args_(parser):
     # Learning rate.
-    # model checkpointing
-    group.add_argument('--resume-dataloader', action='store_true',
-                       help='Resume the dataloader when resuming training. '
-                       'Does not apply to tfrecords dataloader, try resuming'
-                       'with a different seed in this case.')
-    # distributed training args
     # autoresume
     group.add_argument('--adlr-autoresume', action='store_true',
                        help='enable autoresume on adlr cluster.')
@@ -361,7 +384,7 @@ def add_text_generate_args(parser):
     return parser
 
 
-def add_data_args(parser):
+def add_data_args_(parser):
     """Train/valid/test data arguments."""
 
     group = parser.add_argument_group('data', 'data configurations')
@@ -382,22 +405,13 @@ def add_data_args(parser):
                        help='path(s) to the validation data.')
     group.add_argument('--test-data', nargs='*', default=None,
                        help='path(s) to the testing data.')
-    group.add_argument('--data-path', nargs='+', default=None,
-                       help='path to combined dataset to split')
-    group.add_argument('--split', default='1000,1,1',
-                       help='comma-separated list of proportions for training,'
-                       ' validation, and test split')
-    group.add_argument('--seq-length', type=int, default=512,
-                       help="Maximum sequence length to process")
     group.add_argument('--max-preds-per-seq', type=int, default=None,
                        help='Maximum number of predictions to use per sequence.'
                        'Defaults to math.ceil(`--seq-length`*.15/10)*10.'
                        'MUST BE SPECIFIED IF `--data-loader tfrecords`.')
 
     # arguments for binary data loader
-    parser.add_argument('--vocab', type=str, default='vocab.txt',
-                        help='path to vocab file')
     parser.add_argument('--data-impl', type=str, default='infer',
                         help='implementation of indexed datasets',
                         choices=['lazy', 'cached', 'mmap', 'infer'])
@@ -405,12 +419,6 @@ def add_data_args(parser):
                         help='Maximum number of samples to plan for, defaults to total iters * batch-size.')
     parser.add_argument('--data-epochs', type=int, default=None,
                         help='Number of epochs to plan for, defaults to using --max-num-samples')
-    parser.add_argument('--mask-prob', default=0.15, type=float,
-                        help='probability of replacing a token with mask')
-    parser.add_argument('--short-seq-prob', default=0.1, type=float,
-                        help='probability of producing a short sequence')
-    parser.add_argument('--skip-mmap-warmup', action='store_true',
-                        help='skip warming up mmap files')
 
     # arguments for numpy data loader
     group.add_argument('--input-data-sizes-file', type=str, default='sizes.txt',
@@ -432,8 +440,6 @@ def add_data_args(parser):
                        help='Dataset content consists of documents where '
                        'each document consists of newline separated sentences')
-    group.add_argument('--num-workers', type=int, default=2,
-                       help="""Number of workers to use for dataloading""")
     group.add_argument('--tokenizer-model-type', type=str,
                        default='bert-large-uncased',
                        help="Model type to use for sentencepiece tokenization \
@@ -470,6 +476,7 @@ def get_args_(extra_args_provider=None):
     parser = add_mixed_precision_args(parser)
     parser = add_distributed_args(parser)
     parser = add_validation_args(parser)
+    parser = add_data_args(parser)
 
     #parser.print_help()
     #exit()
@@ -479,7 +486,7 @@ def get_args_(extra_args_provider=None):
     parser = add_training_args_(parser)
    parser = add_evaluation_args(parser)
     parser = add_text_generate_args(parser)
-    parser = add_data_args(parser)
+    parser = add_data_args_(parser)
     if extra_args_provider is not None:
         parser = extra_args_provider(parser)
...
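For reference, the new data-argument group added above can be exercised on its own with a plain argparse parser. The sketch below is illustrative only: the sample data prefix and vocab path are invented, and the split handling shows one plausible reading of the `--split` help text (proportions normalized by their sum), not necessarily the dataset builder's exact code.

# Illustrative sketch of the flags introduced by the new add_data_args group.
# Sample values are invented; this is not repository code.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--data-path', type=str, required=True)
group.add_argument('--split', type=str, required=True)
group.add_argument('--vocab-file', type=str, required=True)
group.add_argument('--seq-length', type=int, required=True)
group.add_argument('--mask-prob', type=float, default=0.15)
group.add_argument('--short-seq-prob', type=float, default=0.1)
group.add_argument('--mmap-warmup', action='store_true')
group.add_argument('--num-workers', type=int, default=2)

args = parser.parse_args(['--data-path', 'my_corpus_text_sentence',  # invented prefix
                          '--split', '90,5,5',
                          '--vocab-file', 'bert-vocab.txt',           # invented path
                          '--seq-length', '512'])

# One reading of the help text: '90,5,5' -> 90% train, 5% validation, 5% test.
weights = [float(w) for w in args.split.split(',')]
fractions = [w / sum(weights) for w in weights]
print(fractions)  # [0.9, 0.05, 0.05]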
@@ -486,20 +486,19 @@ def evaluate_and_print_results(prefix, forward_step_func,
 def get_train_val_test_data_iterators(train_data, val_data, test_data, args):
     """Build train/validation/test iterators"""
 
-    # If resume is on, shift the start iterations.
-    if args.resume_dataloader:
-        if train_data is not None:
-            train_data.batch_sampler.start_iter = args.iteration % \
-                len(train_data)
-            print_rank_0('setting training data start iteration to {}'.
-                         format(train_data.batch_sampler.start_iter))
-        if val_data is not None:
-            start_iter_val = (args.iteration // args.eval_interval) * \
-                args.eval_iters
-            val_data.batch_sampler.start_iter = start_iter_val % \
-                len(val_data)
-            print_rank_0('setting validation data start iteration to {}'.
-                         format(val_data.batch_sampler.start_iter))
+    # Shift the start iterations.
+    if train_data is not None:
+        train_data.batch_sampler.start_iter = args.iteration % \
+            len(train_data)
+        print_rank_0('setting training data start iteration to {}'.
+                     format(train_data.batch_sampler.start_iter))
+    if val_data is not None:
+        start_iter_val = (args.iteration // args.eval_interval) * \
+            args.eval_iters
+        val_data.batch_sampler.start_iter = start_iter_val % \
+            len(val_data)
+        print_rank_0('setting validation data start iteration to {}'.
+                     format(val_data.batch_sampler.start_iter))
 
     if train_data is not None:
         train_data_iterator = iter(train_data)
...
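With the resume guard removed, the start iterations above are always derived from args.iteration. A standalone check of that arithmetic is below; every number is invented for illustration and does not come from the repository.

# Standalone check of the start-iteration arithmetic above (numbers invented).
iteration = 1000        # args.iteration at startup
eval_interval = 100     # args.eval_interval
eval_iters = 10         # args.eval_iters
len_train, len_val = 800, 64   # len(train_data), len(val_data)

train_start = iteration % len_train
val_start = ((iteration // eval_interval) * eval_iters) % len_val
print(train_start, val_start)  # 200 36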
@@ -118,17 +118,6 @@ def get_train_val_test_data(args):
         print_rank_0('> building train, validation, and test datasets '
                      'for BERT ...')
 
-        if args.data_loader is None:
-            args.data_loader = 'binary'
-        if args.data_loader != 'binary':
-            print('Unsupported {} data loader for BERT.'.format(
-                args.data_loader))
-            exit(1)
-        if not args.data_path:
-            print('BERT only supports a unified dataset specified '
-                  'with --data-path')
-            exit(1)
-
         data_parallel_size = mpu.get_data_parallel_world_size()
         data_parallel_rank = mpu.get_data_parallel_rank()
         global_batch_size = args.batch_size * data_parallel_size
@@ -137,7 +126,7 @@ def get_train_val_test_data(args):
         train_iters = args.train_iters
         eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters
         test_iters = args.eval_iters
-        train_val_test_num_samples = [args.train_iters * global_batch_size,
+        train_val_test_num_samples = [train_iters * global_batch_size,
                                       eval_iters * global_batch_size,
                                       test_iters * global_batch_size]
         print_rank_0(' > datasets target sizes (minimum size):')
@@ -145,10 +134,9 @@ def get_train_val_test_data(args):
         print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
         print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))
 
-        assert len(args.data_path) == 1
         train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-            vocab_file=args.vocab,
-            data_prefix=args.data_path[0],
+            vocab_file=args.vocab_file,
+            data_prefix=args.data_path,
             data_impl=args.data_impl,
             splits_string=args.split,
             train_valid_test_num_samples=train_val_test_num_samples,
@@ -156,7 +144,7 @@ def get_train_val_test_data(args):
             masked_lm_prob=args.mask_prob,
             short_seq_prob=args.short_seq_prob,
             seed=args.seed,
-            skip_warmup=args.skip_mmap_warmup)
+            skip_warmup=(not args.mmap_warmup))
         print_rank_0("> finished creating BERT datasets ...")
 
 def make_data_loader_(dataset):
...
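The last hunk also flips the warm-up default: the old --skip-mmap-warmup flag was forwarded directly, so warm-up ran unless the flag was given, while the new call passes not args.mmap_warmup, so warm-up is skipped unless --mmap-warmup is given. A minimal illustration of the two flag semantics follows; it is plain Python written for this note, not repository code.

# Minimal illustration of the flag semantics before and after this commit.
def old_skip_warmup(skip_mmap_warmup):   # --skip-mmap-warmup, store_true, default False
    return skip_mmap_warmup

def new_skip_warmup(mmap_warmup):        # --mmap-warmup, store_true, default False
    return not mmap_warmup

# With neither flag passed, the default behavior changes:
print(old_skip_warmup(False), new_skip_warmup(False))  # False True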