Commit d6485684 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi
Browse files

fixed gpt-2 dataloader

parent b8e0129f
...@@ -115,7 +115,7 @@ def get_train_val_test_data(args): ...@@ -115,7 +115,7 @@ def get_train_val_test_data(args):
if args.data_loader == 'numpy': if args.data_loader == 'numpy':
(train_data, val_data, test_data), num_tokens, \ (train_data, val_data, test_data), num_tokens, \
eod_token = make_gpt2_dataloaders(args) eod_token = make_gpt2_dataloaders(args)
elif args.data_loader == 'raw' or args.data_loader == 'tfrecords' elif args.data_loader == 'raw' or args.data_loader == 'lazy'
data_config = configure_data() data_config = configure_data()
data_config.set_defaults(data_set_type='GPT2', transpose=False) data_config.set_defaults(data_set_type='GPT2', transpose=False)
(train_data, val_data, test_data), tokenizer = data_config.apply( (train_data, val_data, test_data), tokenizer = data_config.apply(
...@@ -123,6 +123,9 @@ def get_train_val_test_data(args): ...@@ -123,6 +123,9 @@ def get_train_val_test_data(args):
num_tokens = tokenizer.num_tokens num_tokens = tokenizer.num_tokens
eod_token = tokenizer.get_command('eos').Id eod_token = tokenizer.get_command('eos').Id
assert eod_token == tokenizer.get_command('pad').Id assert eod_token == tokenizer.get_command('pad').Id
else:
print("Unsupported data loader for GPT2.")
exit(1)
# pad. # pad.
num_tokens = vocab_size_with_padding(num_tokens, args) num_tokens = vocab_size_with_padding(num_tokens, args)
print_rank_0('> found end-of-document token: {}'.format(eod_token)) print_rank_0('> found end-of-document token: {}'.format(eod_token))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment