Commit ee2490d5 authored by Neel Kant

Start creating REALMBertModel

parent 2d98cfbf
@@ -89,6 +89,12 @@ class InverseClozeDataset(Dataset):
        token_types = [0] * self.max_seq_length
        return tokens, token_types, pad_mask
    def get_block(self, start_idx, end_idx, doc_idx, block_idx):
        """Get the token IDs of a block and the title of its source document."""
        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
        title = list(self.titles_dataset[int(doc_idx)])
        # truncate so that [CLS] title [SEP] block [SEP] fits in max_seq_length
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        return block, title
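For context, the truncation budget above leaves room for the title plus three special tokens. A minimal sketch of how such a sample could be assembled (the helper name and ID variables are illustrative, not the repo's API):

    def build_title_block_sample(cls_id, sep_id, title, block, max_seq_length):
        # [CLS] title [SEP] block [SEP]: 3 special tokens plus len(title) positions
        # are reserved, so the block gets max_seq_length - (3 + len(title)) tokens.
        tokens = [cls_id] + title + [sep_id] + block + [sep_id]
        assert len(tokens) <= max_seq_length
        return tokens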
    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
...
import numpy as np
import spacy
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.bert_dataset import get_samples_mapping_
from megatron.data.dataset_utils import build_simple_training_sample
qa_nlp = spacy.load('en_core_web_lg')
class RealmDataset(Dataset):
    """Dataset containing simple masked sentences for masked language modeling.

    The dataset should yield sentences just like the regular BertDataset.
    However, this dataset also needs to be able to return a set of blocks
    given their start and end indices.
    """
    def __init__(self, name, indexed_dataset, data_prefix,
                 num_epochs, max_num_samples, masked_lm_prob,
                 max_seq_length, short_seq_prob, seed):
@@ -58,3 +68,14 @@ class RealmDataset(Dataset):
            self.mask_id, self.pad_id,
            self.masked_lm_prob, np_rng)
def spacy_ner(block_text):
    """Run spaCy NER over a block of text and return entity character offsets and strings."""
    candidates = {}
    block = qa_nlp(block_text)
    starts = []
    answers = []
    for ent in block.ents:
        starts.append(int(ent.start_char))
        answers.append(str(ent.text))
    candidates['starts'] = starts
    candidates['answers'] = answers
    return candidates
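A quick usage sketch of the helper above (the sentence and the extracted entities are only an illustration of typical spaCy output):

    candidates = spacy_ner("Abraham Lincoln was born in Kentucky in 1809.")
    # candidates['starts']  -> character offsets of each entity in the block text
    # candidates['answers'] -> entity strings, e.g. ['Abraham Lincoln', 'Kentucky', '1809']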
@@ -14,6 +14,6 @@
# limitations under the License.
from .distributed import *
from .bert_model import BertModel, ICTBertModel, REALMBertModel
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
@@ -214,8 +214,49 @@ class BertModel(MegatronModule):
            state_dict[self._ict_head_key], strict=strict)
class REALMBertModel(MegatronModule):
    """REALM model: a BertModel without the binary head, plus an ICT retriever.
    It needs a different kind of dataset than ICT though."""

    def __init__(self, ict_model_path, block_hash_data_path):
        super(REALMBertModel, self).__init__()
        bert_args = dict(
            num_tokentypes=2,
            add_binary_head=False,
            parallel_output=True
        )
        self.lm_model = BertModel(**bert_args)
        self._lm_key = 'realm_lm'

        # WIP: the pretrained ICT retriever, its dataset and the block hash /
        # embedding data still need to be loaded from ict_model_path and
        # block_hash_data_path; placeholders until that is implemented.
        self.ict_model = None
        self.ict_dataset = None
        self.hash_data = None    # {'matrix': hashing projection, hash bucket -> block info}
        self.block_data = None   # block index -> block embedding
    def forward(self, tokens, attention_mask, token_types):
        # [batch_size x embed_size]
        query_logits = self.ict_model.embed_query(tokens, attention_mask, token_types)

        hash_matrix_pos = self.hash_data['matrix']
        # [batch_size x num_buckets / 2]
        query_hash_pos = torch.matmul(query_logits, hash_matrix_pos)
        query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), dim=1)
        # [batch_size]
        query_hashes = torch.argmax(query_hash_full, dim=1)

        batch_block_embeds = []
        for query_hash in query_hashes:
            # TODO: this should be made into a single np.array in preprocessing
            bucket_blocks = self.hash_data[int(query_hash)]
            block_indices = bucket_blocks[:, 3]
            # [bucket_pop x embed_size]
            block_embeds = [self.block_data[idx] for idx in block_indices]
            # will become [batch_size x bucket_pop x embed_size];
            # will require padding to do tensor multiplication
            batch_block_embeds.append(block_embeds)

        # WIP: buckets have different populations, so this needs padding
        # before it can be turned into a single tensor and scored.
        batch_block_embeds = torch.tensor(np.array(batch_block_embeds))
        retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 0, 1))
        return retrieval_scores
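Since each hash bucket holds a different number of blocks, the scoring above still needs padding before the candidate embeddings can form a single tensor. A minimal sketch of one way to do that, assuming torch and numpy are in scope and every block embedding has the same width (the function name is illustrative, not the repo's API):

    def pad_block_embeds(batch_block_embeds, embed_size):
        # Pad each bucket's [bucket_pop x embed_size] list up to the largest
        # population so the result is a dense [batch x max_pop x embed_size] tensor.
        max_pop = max(len(blocks) for blocks in batch_block_embeds)
        padded = torch.zeros(len(batch_block_embeds), max_pop, embed_size)
        for i, blocks in enumerate(batch_block_embeds):
            padded[i, :len(blocks)] = torch.tensor(np.stack(blocks))
        return padded

    # per-query scores against its own candidates: [batch x max_pop]
    # scores = torch.bmm(padded, query_logits.unsqueeze(2)).squeeze(2)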
class ICTBertModel(MegatronModule):
@@ -249,6 +290,11 @@ class ICTBertModel(MegatronModule):
        return query_logits, block_logits
    def embed_query(self, query_tokens, query_attention_mask, query_types):
        """Embed a batch of queries with the ICT question encoder."""
        query_ict_logits, _ = self.question_model.forward(query_tokens, 1 - query_attention_mask, query_types)
        return query_ict_logits
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Save dict with state dicts of each of the models."""
...
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import ICTBertModel, REALMBertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
num_batches = 0
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building BERT models ...')
    realm_model = REALMBertModel(args.ict_model_path,
                                 args.block_hash_data_path)
    return realm_model
def get_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_types', 'query_pad_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_types = data_b['query_types'].long()
query_pad_mask = data_b['query_pad_mask'].long()
return query_tokens, query_types, query_pad_mask
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers('batch generator').start()
query_tokens, query_types, query_pad_mask = get_batch(data_iterator)
timers('batch generator').stop()
    # Forward model. REALMBertModel embeds the queries and scores them against
    # its hashed candidate blocks, so no block tensors are fed in here.
    # [batch x num_candidate_blocks]
    retrieval_scores = model(query_tokens, query_pad_mask, query_types).float()
softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
batch_size = softmaxed.shape[0]
top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])
    # F.cross_entropy applies log_softmax internally, so feed it the raw scores
    retrieval_loss = F.cross_entropy(retrieval_scores, torch.arange(batch_size).cuda())
reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
return retrieval_loss, {'retrieval loss': reduced_losses[0],
'top1_acc': reduced_losses[1],
'top5_acc': reduced_losses[2]}
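For reference, the per-batch accuracies above can also be computed without Python loops; a small sketch, assuming as above that row i's gold block is candidate i (the function name is illustrative):

    def retrieval_accuracies(scores, k=5):
        # scores: [batch x num_candidates]; the gold label for row i is candidate i
        gold = torch.arange(scores.shape[0], device=scores.device)
        topk = torch.topk(scores, k=k, dim=1).indices
        top1_acc = (topk[:, 0] == gold).float().mean()
        topk_acc = (topk == gold.unsqueeze(1)).any(dim=1).float().mean()
        return top1_acc, topk_acc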
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
ict_dataset=True)
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})