Commit 1dd51c0e authored by Neel Kant

pretrain_bert_ict.py compiles and runs

parent b1efc33d
@@ -19,7 +19,7 @@ import math
 import torch
 from .samplers import DistributedBatchSampler
-from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
+from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset, InverseClozeDataset
 from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
 from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
 from . import corpora
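The newly imported InverseClozeDataset implements the Inverse Cloze Task (ICT), in which one sentence drawn from a passage acts as the query and the remaining sentences act as the context it should retrieve. A toy sketch of that sampling idea follows; it mirrors the general ICT recipe, not the actual interface of InverseClozeDataset:

    import random

    # Toy Inverse Cloze Task sampling: the query is one sentence drawn
    # from a passage, the context is the remaining sentences. This is
    # illustrative only, not InverseClozeDataset's real API.
    def ict_sample(sentences):
        idx = random.randrange(len(sentences))
        query = sentences[idx]
        context = sentences[:idx] + sentences[idx + 1:]
        return query, context

    passage = ["Roses are red.", "Violets are blue.", "Sugar is sweet."]
    query, context = ict_sample(passage)
    assert query not in context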
@@ -120,14 +120,20 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
         ds = split_ds(ds, split)
         if 'bert' in ds_type.lower():
             presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
-            dstype = bert_sentencepair_dataset
-            ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
+            if 'ict' in ds_type.lower():
+                dstype = InverseClozeDataset
+            else:
+                dstype = bert_sentencepair_dataset
+            ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
         elif ds_type.lower() == 'gpt2':
             ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
     else:
         if 'bert' in ds_type.lower():
             presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
-            dstype = bert_sentencepair_dataset
+            if 'ict' in ds_type.lower():
+                dstype = InverseClozeDataset
+            else:
+                dstype = bert_sentencepair_dataset
             ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
         elif ds_type.lower() == 'gpt2':
             ds = GPT2Dataset(ds, max_seq_len=seq_length)
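Both branches select the dataset class by substring, so the composite string 'BERT_ict' passes the outer 'bert' test and then the inner 'ict' test, landing on InverseClozeDataset, while plain 'BERT' keeps the sentence-pair dataset. A minimal sketch of just that dispatch; the helper name pick_dataset_class is hypothetical, with strings standing in for the actual classes:

    # Hypothetical helper isolating the dispatch above; strings stand in
    # for the dataset classes so the sketch is self-contained.
    def pick_dataset_class(ds_type):
        if 'bert' in ds_type.lower():
            # 'ict' is checked inside the 'bert' branch, so the composite
            # string 'BERT_ict' matches both tests and selects ICT.
            if 'ict' in ds_type.lower():
                return 'InverseClozeDataset'
            return 'bert_sentencepair_dataset'
        elif ds_type.lower() == 'gpt2':
            return 'GPT2Dataset'

    assert pick_dataset_class('BERT_ict') == 'InverseClozeDataset'
    assert pick_dataset_class('BERT') == 'bert_sentencepair_dataset'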
@@ -924,6 +924,7 @@ class InverseClozeDataset(data.Dataset):
             'context_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask)
         }
+        return sample

     def get_sentence_split_doc(self, idx):
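The one-line `return sample` fix addresses a `__getitem__` that built the sample dict and then fell off the end, handing None to any DataLoader. An illustrative reduction of the bug and the fix, with the field names taken from the diff and everything else invented for the example:

    import numpy as np

    # A function with no explicit return yields None for every item.
    def getitem_without_return(context_token_types, context_pad_mask):
        sample = {
            'context_types': np.array(context_token_types),
            'context_pad_mask': np.array(context_pad_mask),
        }
        # missing return -> caller receives None

    def getitem_with_return(context_token_types, context_pad_mask):
        sample = {
            'context_types': np.array(context_token_types),
            'context_pad_mask': np.array(context_pad_mask),
        }
        return sample

    assert getitem_without_return([0], [1]) is None
    assert getitem_with_return([0], [1])['context_types'].shape == (1,)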
@@ -1015,4 +1016,5 @@ class InverseClozeDataset(data.Dataset):
         num_pad = max(0, self.max_seq_len - len(tokens))
         pad_mask = [0] * len(tokens) + [1] * num_pad
         tokens += [self.tokenizer.get_command('pad').Id] * num_pad
+        token_types += [token_types[0]] * num_pad
         return tokens, token_types, pad_mask
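The added line keeps `token_types` the same length as `tokens` once padding is applied; before the fix only `tokens` and `pad_mask` grew, so the three parallel lists came back with mismatched lengths. A self-contained sketch of the padding step, assuming a fixed PAD_ID in place of `self.tokenizer.get_command('pad').Id` and a hypothetical max length:

    # PAD_ID and MAX_SEQ_LEN stand in for tokenizer- and config-derived
    # values; the padding logic itself mirrors the hunk above.
    PAD_ID = 0
    MAX_SEQ_LEN = 8

    def pad_seq(tokens, token_types):
        num_pad = max(0, MAX_SEQ_LEN - len(tokens))
        pad_mask = [0] * len(tokens) + [1] * num_pad
        tokens = tokens + [PAD_ID] * num_pad
        # Pad token_types with its first element so all three parallel
        # lists stay the same length (the fix in this commit).
        token_types = token_types + [token_types[0]] * num_pad
        return tokens, token_types, pad_mask

    tokens, token_types, pad_mask = pad_seq([5, 6, 7], [1, 1, 1])
    assert len(tokens) == len(token_types) == len(pad_mask) == MAX_SEQ_LEN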
@@ -96,7 +96,7 @@ def forward_step(data_iterator, model, args, timers):
                                     context_tokens, 1 - context_pad_mask, context_types)
     softmaxed = F.softmax(retrieval_scores, dim=0).float()
-    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]))
+    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]).cuda())
     reduced_losses = reduce_losses([retrieval_loss])
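`F.cross_entropy` requires its input and target tensors on the same device, and the retrieval scores live on the GPU, hence the `.cuda()` on the `arange` targets. Those targets encode the in-batch objective: query i's positive context is context i, so the label vector is simply 0..batch-1. A minimal sketch under the assumption that `retrieval_scores` is a [batch, batch] similarity matrix; note that `F.cross_entropy` applies log-softmax internally, so the sketch feeds it the raw scores rather than a softmaxed copy:

    import torch
    import torch.nn.functional as F

    # Random stand-in for the model's query-vs-context score matrix;
    # shapes and device handling mirror the fix above.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch = 4
    retrieval_scores = torch.randn(batch, batch, device=device)

    # The correct context for query i sits at index i, so the targets
    # are 0..batch-1, created on the same device as the scores.
    targets = torch.arange(retrieval_scores.size(0), device=device)
    retrieval_loss = F.cross_entropy(retrieval_scores, targets)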
@@ -114,7 +114,7 @@ def get_train_val_test_data(args):
             or args.data_loader == 'lazy'
             or args.data_loader == 'tfrecords'):
         data_config = configure_data()
-        ds_type = 'BERT'
+        ds_type = 'BERT_ict'
         data_config.set_defaults(data_set_type=ds_type, transpose=False)
         (train_data, val_data, test_data), tokenizer = data_config.apply(args)
         num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)