Commit 32bb4edc authored by Neel Kant

Prune changes to only be related to ICT

parent 674814a5
import os
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.model import BertModel
from megatron.model.utils import get_linear_layer, init_method_normal
...@@ -12,294 +13,6 @@ from megatron.utils import report_memory
from megatron import mpu
class REALMAnswerSpanModel(MegatronModule):
def __init__(self, realm_model, mlp_hidden_size=64):
super(REALMAnswerSpanModel, self).__init__()
self.realm_model = realm_model
self.mlp_hidden_size = mlp_hidden_size
args = get_args()
init_method = init_method_normal(args.init_method_std)
self.fc1 = get_linear_layer(2 * args.hidden_size, self.mlp_hidden_size, init_method)
self._fc1_key = 'fc1'
self.fc2 = get_linear_layer(self.mlp_hidden_size, 1, init_method)
self._fc2_key = 'fc2'
max_length = 10
self.start_ends = []
for length in range(max_length):
self.start_ends.extend([(i, i + length) for i in range(288 - length)])
def forward(self, question_tokens, question_attention_mask, answer_tokens, answer_token_lengths):
lm_logits, block_probs, topk_block_tokens = self.realm_model(
question_tokens, question_attention_mask, query_block_indices=None, return_topk_block_tokens=True)
batch_span_reps, batch_loss_masks = [], []
# go through batch one-by-one
for i in range(len(answer_token_lengths)):
answer_length = answer_token_lengths[i]
answer_span_tokens = answer_tokens[i][:answer_length]
span_reps, loss_masks = [], []
# go through the top k for the batch item
for logits, block_tokens in zip(lm_logits[i], topk_block_tokens[i]):
block_logits = logits[len(logits) // 2:]
span_starts = range(len(block_tokens) - (answer_length - 1))
# record the start, end indices of spans which match the answer
matching_indices = set([
(idx, idx + answer_length - 1) for idx in span_starts
if np.array_equal(block_tokens[idx:idx + answer_length], answer_span_tokens)
])
# create a mask for computing the loss on P(y | z, x)
# [num_spans]
loss_masks.append(torch.LongTensor([int(idx_pair in matching_indices) for idx_pair in self.start_ends]))
# get all of the candidate spans that need to be fed to MLP
# [num_spans x 2 * embed_size]
span_reps.append(torch.stack([torch.cat((block_logits[s], block_logits[e])) for (s, e) in self.start_ends]))
# data for all k blocks for a single batch item
# [k x num_spans]
batch_loss_masks.append(torch.stack(loss_masks))
# [k x num_spans x 2 * embed_size]
batch_span_reps.append(torch.stack(span_reps))
# data for all batch items
# [batch_size x k x num_spans]
batch_loss_masks = torch.stack(batch_loss_masks)
batch_span_reps = torch.stack(batch_span_reps)
# [batch_size x k x num_spans]
batch_span_logits = self.fc2(self.fc1(batch_span_reps)).squeeze()
return batch_span_logits, batch_loss_masks, block_probs
# block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
# lm_logits = torch.sum(lm_logits * block_probs, dim=1)
class REALMBertModel(MegatronModule):
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
num_tokentypes=2,
add_binary_head=False,
parallel_output=True
)
self.lm_model = BertModel(**bert_args)
load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
self._lm_key = 'realm_lm'
self.retriever = retriever
self.top_k = self.retriever.top_k
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
# print("\nNEW FORWARD", '-' * 100, flush=True)
dset = self.retriever.ict_dataset
det_tokens = detach(tokens)[0].tolist()
det_attention = detach(attention_mask)[0].tolist()
# print("\nTokens: ", det_tokens, '\n', flush=True)
# print("\nAttention: ", det_attention, '\n', flush=True)
# print("pad id: ", dset.pad_id, flush=True)
assert bool(0 in det_attention) == bool(dset.pad_id in det_tokens)
if 0 in det_attention:
idx_padid = det_tokens.index(dset.pad_id)
idx_attn = det_attention.index(0)
assert idx_padid == idx_attn, (idx_padid, idx_attn)
# text = dset.decode_tokens(det_tokens)
# print(text, flush=True)
# print("Token shape: ", tokens.shape, flush=True)
# [batch_size x k x seq_length]
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
# print("Top k block shape: ", topk_block_tokens.shape, flush=True)
batch_size = tokens.shape[0]
# create a copy in case it needs to be returned
ret_topk_block_tokens = np.array(topk_block_tokens)
seq_length = topk_block_tokens.shape[2]
long_tensor = torch.cuda.LongTensor
topk_block_tokens = long_tensor(topk_block_tokens).reshape(-1, seq_length)
topk_block_attention_mask = long_tensor(topk_block_attention_mask).reshape(-1, seq_length)
# print('Block token shape: ', topk_block_tokens.shape, flush=True)
# [batch_size x k x embed_size]
true_model = self.retriever.ict_model.module.module
fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
# print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
# [batch_size x embed_size x 1]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(2)
# print('Query logits shape: ', query_logits.shape, flush=True)
# [batch_size x k]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
# print('Block score shape: ', fresh_block_scores.shape, flush=True)
block_probs = F.softmax(fresh_block_scores, dim=1)
# [batch_size * k x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
#assert all(tokens[i] == tokens[0] for i in range(self.top_k))
#assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
#assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
# [batch_size * k x 2 * seq_length]
lm_input_batch_shape = (batch_size * self.top_k, 2 * seq_length)
all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
all_attention_mask = all_tokens.clone()
all_token_types = all_tokens.clone()
#all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
#all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
#all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
query_lengths = torch.sum(attention_mask, axis=1)
# all blocks (including null ones) will have two SEP tokens
block_sep_indices = (topk_block_tokens == dset.sep_id).nonzero().reshape(batch_size * self.top_k, 2, 2)
# block body starts after the first SEP
block_starts = block_sep_indices[:, 0, 1] + 1
# block body ends after the second SEP
block_ends = block_sep_indices[:, 1, 1] + 1
# block_lengths = torch.sum(topk_block_attention_mask, axis=1)
for row_num in range(all_tokens.shape[0]):
q_len = query_lengths[row_num]
b_start = block_starts[row_num]
b_end = block_ends[row_num]
# new tokens = CLS + query + SEP + block + SEP
new_tokens_length = q_len + b_end - b_start
# splice query and block tokens accordingly
all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
# print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
all_attention_mask[row_num, :new_tokens_length] = 1
all_attention_mask[row_num, new_tokens_length:] = 0
# [batch_size x k x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
if return_topk_block_tokens:
return lm_logits, block_probs, ret_topk_block_tokens
return lm_logits, block_probs
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
self.retriever.load_state_dict(state_dict[self._retriever_key], strict)
class REALMRetriever(MegatronModule):
"""Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
super(REALMRetriever, self).__init__()
self.ict_model = ict_model
self.ict_dataset = ict_dataset
self.block_data = block_data
self.hashed_index = hashed_index
self.top_k = top_k
self._ict_key = 'ict_model'
def reload_index(self):
args = get_args()
self.block_data = BlockData.load_from_file(args.block_data_path)
print("resetting index", flush=True)
self.hashed_index.reset_index()
self.hashed_index.add_block_embed_data(self.block_data)
def prep_query_text_for_retrieval(self, query_text):
padless_max_len = self.ict_dataset.max_seq_length - 2
query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
return query_tokens, query_pad_mask
def retrieve_evidence_blocks_text(self, query_text):
"""Get the top k evidence blocks for query_text in text form"""
print("-" * 100)
print("Query: ", query_text)
query_tokens, query_pad_mask = self.prep_query_text_for_retrieval(query_text)
topk_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
for i, block in enumerate(topk_block_tokens[0]):
block_text = self.ict_dataset.decode_tokens(block)
print('\n > Block {}: {}'.format(i, block_text))
def retrieve_evidence_blocks(self, query_tokens, query_pad_mask, query_block_indices=None, include_null_doc=False):
"""Embed blocks to be used in a forward pass"""
with torch.no_grad():
if hasattr(self.ict_model, 'module'):
true_model = self.ict_model.module
if hasattr(true_model, 'module'):
true_model = true_model.module
else:
true_model = self.ict_model
# print("true model: ", true_model, flush=True)
query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_topk_tokens, all_topk_pad_masks = [], []
# this will result in no candidate exclusion
if query_block_indices is None:
query_block_indices = [-1] * len(block_indices)
top_k_offset = int(include_null_doc)
for query_idx, indices in enumerate(block_indices):
# [k x meta_dim]
# exclude trivial candidate if it appears, else just trim the weakest in the top-k
topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - top_k_offset]]
if include_null_doc:
topk_block_data.append(self.ict_dataset.get_null_block())
topk_tokens, topk_pad_masks = zip(*topk_block_data)
all_topk_tokens.append(np.array(topk_tokens))
all_topk_pad_masks.append(np.array(topk_pad_masks))
# [batch_size x k x seq_length]
return np.array(all_topk_tokens), np.array(all_topk_pad_masks)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
self.ict_model.load_state_dict(state_dict[self._ict_key], strict)
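A minimal usage sketch of the REALMRetriever above (not part of the original diff); it mirrors model_provider() in pretrain_realm.py further down and assumes the indexer helpers and args used there.
from indexer import load_ict_checkpoint, get_ict_dataset
from megatron import get_args
from megatron.data.realm_index import BlockData, FaissMIPSIndex

args = get_args()
ict_model = load_ict_checkpoint(from_realm_chkpt=False)      # pretrained ICT query/block encoders
ict_dataset = get_ict_dataset(use_titles=False)              # tokenization and block lookup helpers
block_data = BlockData.load_from_file(args.block_data_path)  # precomputed block embeddings
mips_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
mips_index.add_block_embed_data(block_data)
retriever = REALMRetriever(ict_model, ict_dataset, block_data, mips_index, top_k=args.block_top_k)
retriever.retrieve_evidence_blocks_text("who discovered penicillin?")  # prints the top-k evidence blocks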
class ICTBertModel(MegatronModule):
"""Bert-based module for Inverse Cloze task."""
def __init__(self,
...@@ -341,10 +54,6 @@ class ICTBertModel(MegatronModule):
block_logits = self.embed_block(block_tokens, block_attention_mask)
return query_logits, block_logits
# [batch x embed] * [embed x batch]
# retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
# return retrieval_scores
def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model"""
if self.use_query_model:
...@@ -391,10 +100,8 @@ class ICTBertModel(MegatronModule):
state_dict[self._block_key], strict=strict)
def init_state_dict_from_bert(self):
"""Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining"""
args = get_args()
import os
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
if not os.path.isfile(tracker_filename):
raise FileNotFoundError("Could not find BERT load for ICT")
...@@ -412,8 +119,11 @@ class ICTBertModel(MegatronModule):
except BaseException:
raise ValueError("Could not load checkpoint")
# load the LM state dict into each model
model_dict = state_dict['model']['language_model']
self.query_model.language_model.load_state_dict(model_dict)
self.block_model.language_model.load_state_dict(model_dict)
# give each model the same ict_head to begin with as well
query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
...@@ -78,7 +78,7 @@ def broadcast_data(keys, data, datatype):
members of the same model parallel group.
Arguments:
keys: list of keys in the data dictionary to be broadcasted
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
with keys.
...
...@@ -16,7 +16,6 @@
"""Model and data parallel groups."""
import datetime
import torch
from .utils import ensure_divisibility
...@@ -27,11 +26,6 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_GLOO_COMM_GROUP = None
_TRAIN_GROUP = None
_INDEX_GROUP = None
_INDEX_READY = None
# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE = None
_MPU_RANK = None
...@@ -102,13 +96,6 @@ def get_model_parallel_group():
return _MODEL_PARALLEL_GROUP
def set_model_parallel_group(group):
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, \
'model parallel group has already been initialized'
_MODEL_PARALLEL_GROUP = group
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
...@@ -116,13 +103,6 @@ def get_data_parallel_group():
return _DATA_PARALLEL_GROUP
def set_data_parallel_group(group):
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group has already been initialized'
_DATA_PARALLEL_GROUP = group
def set_model_parallel_world_size(world_size):
"""Set the model parallel size"""
global _MPU_WORLD_SIZE
...@@ -175,40 +155,3 @@ def destroy_model_parallel():
_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
def init_realm_groups(max_training_rank, world_size):
global _GLOO_COMM_GROUP
_GLOO_COMM_GROUP = torch.distributed.new_group(list(range(world_size)),
backend="gloo",
timeout=datetime.timedelta(0, 7200))
global _TRAIN_GROUP
_TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
global _INDEX_GROUP
_INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
global _INDEX_READY
_INDEX_READY = torch.zeros(1)
def get_gloo_comm_group():
global _GLOO_COMM_GROUP
assert _GLOO_COMM_GROUP is not None
return _GLOO_COMM_GROUP
def get_train_group():
global _TRAIN_GROUP
assert _TRAIN_GROUP is not None
return _TRAIN_GROUP
def get_index_group():
global _INDEX_GROUP
assert _INDEX_GROUP is not None
return _INDEX_GROUP
def get_index_ready():
global _INDEX_READY
assert _INDEX_READY is not None
return _INDEX_READY
...@@ -18,8 +18,6 @@
from datetime import datetime
import math
import sys
import time
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
...@@ -37,19 +35,14 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
INDEX_READY = None
def pretrain(train_valid_test_dataset_provider, model_provider,
forward_step_func, extra_args_provider=None, args_defaults={},
initializer_func=None):
"""Main training program. """Main training program.
This function will run the followings in the order provided: This function will run the followings in the order provided:
...@@ -75,14 +68,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider, ...@@ -75,14 +68,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
""" """
# Initalize and get arguments, timers, and Tensorboard writer. # Initalize and get arguments, timers, and Tensorboard writer.
if initializer_func is None: initialize_megatron(extra_args_provider=extra_args_provider,
initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults)
args_defaults=args_defaults)
else:
initializer_func(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
global INDEX_READY
INDEX_READY = get_index_ready()
args = get_args()
timers = get_timers()
...@@ -232,10 +219,8 @@ def setup_model_and_optimizer(model_provider_func):
args.iteration = 0
if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
print("Initializing ICT from pretrained BERT model", flush=True)
model.module.module.init_state_dict_from_bert()
elif args.iteration == 0:
print("Ooops", flush=True)
return model, optimizer, lr_scheduler
...@@ -244,15 +229,12 @@ def backward_step(optimizer, model, loss):
"""Backward step."""
args = get_args()
timers = get_timers()
# torch.cuda.synchronize()
# Backward pass.
# optimizer.zero_grad(set_grads_to_None=True)
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
optimizer.backward(loss, update_master_grads=False)
else:
optimizer.zero_grad()
loss.backward()
# All-reduce if needed.
...@@ -261,9 +243,11 @@ def backward_step(optimizer, model, loss):
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
# Update master gradients.
if args.fp16:
optimizer.update_master_grads()
# Clipping gradients helps prevent the exploding gradient.
if args.clip_grad > 0:
if not args.fp16:
...@@ -283,12 +267,11 @@ def train_step(forward_step_func, data_iterator,
loss, loss_reduced = forward_step_func(data_iterator, model)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss)
timers('backward').stop()
# Calculate gradients, reduce across processes, and clip.
# Update parameters.
timers('optimizer').start()
optimizer.step()
...@@ -383,54 +366,12 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers('interval time').start()
report_memory_flag = True
global INDEX_READY
print('>>> Starting train()', flush=True)
# start off by posting a receive call which will be answered.
# synchronize for start
if args.max_training_rank is not None:
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
while iteration < args.train_iters:
if args.max_training_rank is not None and iteration >= last_reload_iteration + 500 and not recv_handle.is_completed():
time.sleep(5)
continue
# this only applies for realm right here
if args.max_training_rank is not None and recv_handle.is_completed():
# should add check that INDEX_READY == 1 but what else could be happening
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
if hasattr(true_model, 'module'):
true_model = true_model.module
print("> Saving model and reloading index", flush=True)
if args.rank == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
true_model.retriever.reload_index()
if args.rank == 0:
INDEX_READY = 1 - INDEX_READY
torch.cuda.synchronize()
# send handle
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
elif iteration < 20:
print("moving right along", flush=True)
# report_memory("iteration {}".format(iteration))
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
skipped_iters += skipped_iter
iteration += 1
...@@ -463,7 +404,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier(get_data_parallel_group())
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
...
...@@ -25,7 +25,6 @@ from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
from megatron.mpu.initialize import get_data_parallel_group
from megatron.fp16 import FP16_Optimizer
...@@ -33,13 +32,8 @@ def reduce_losses(losses):
"""Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
args = get_args()
if args.max_training_rank is not None:
num_trainers = args.max_training_rank
else:
num_trainers = torch.distributed.get_world_size()
reduced_losses = reduced_losses / num_trainers
return reduced_losses
...@@ -84,7 +78,7 @@ def check_adlr_autoresume_termination(iteration, model,
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistency.
torch.distributed.barrier(get_data_parallel_group())
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
...
...@@ -36,6 +36,7 @@ def model_provider(only_query_model=False, only_block_model=False):
args = get_args()
print_rank_0('building BERT models ...')
# simpler to just keep using 2 tokentypes since the LM we initialize with has 2 tokentypes
model = ICTBertModel(
ict_head_size=128,
num_tokentypes=2,
...@@ -93,19 +94,16 @@ def forward_step(data_iterator, model):
all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
all_block_logits = all_query_logits.clone().cuda()
# record this processes' data and then merge with other processes below
all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
# print(all_query_logits[:, :5], flush=True)
# print(all_block_logits[:, :5], flush=True)
dist.all_reduce(all_query_logits)
dist.all_reduce(all_block_logits)
# print(all_query_logits[:, :5], flush=True)
# print(all_block_logits[:, :5], flush=True)
# scores are inner products between query and block embeddings
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
def topk_acc(k):
...@@ -113,11 +111,6 @@ def forward_step(data_iterator, model):
top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
# correct_probs = torch.gather(softmaxed, 1, torch.arange(global_batch_size).long().cuda().reshape(-1, 1))
# assert correct_probs[3] == softmaxed[3, 3]
# retrieval_loss = -torch.sum(torch.log(correct_probs)) / global_batch_size
reduced_losses = reduce_losses([retrieval_loss, *top_accs])
stats_dict = {
'retrieval loss': reduced_losses[0],
...
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import torch
import torch.nn.functional as F
from indexer import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses, report_memory
from megatron import mpu
from indexer import initialize_and_run_async_megatron
num_batches = 0
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building REALM models ...')
try:
ict_model = load_ict_checkpoint(from_realm_chkpt=True)
except:
ict_model = load_ict_checkpoint(from_realm_chkpt=False)
ict_dataset = get_ict_dataset(use_titles=False)
all_block_data = BlockData.load_from_file(args.block_data_path)
# hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
hashed_index.add_block_embed_data(all_block_data)
# top_k + 1 because we may need to exclude trivial candidate
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
model = REALMBertModel(retriever)
return model
def get_batch(data_iterator):
# Items and their type.
keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['tokens'].long()
labels = data_b['labels'].long()
loss_mask = data_b['loss_mask'].long()
pad_mask = data_b['pad_mask'].long()
query_block_indices = data_b['query_block_indices'].long()
return tokens, labels, loss_mask, pad_mask, query_block_indices
def get_qa_batch(data_iterator):
question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths
def forward_step(data_iterator, model):
"""Forward step."""
timers = get_timers()
# Get the batch.
timers('batch generator').start()
tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
timers('batch generator').stop()
# Forward model.
lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
with torch.no_grad():
max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]
lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
labels.contiguous())
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, null_block_probs])
# torch.cuda.synchronize()
return lm_loss, {'lm_loss': reduced_loss[0],
'max_ru': reduced_loss[1],
'top_ru': reduced_loss[2],
'avg_ru': reduced_loss[3],
'null_prob': reduced_loss[4]}
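For reference (not part of the original diff), the marginalized MLM objective that forward_step above targets, per its P(y|x) comment, in LaTeX:

\[
\log p(y \mid x) \;=\; \log \sum_{z \in \mathrm{top}\text{-}k(x)} p(y \mid z, x)\, p(z \mid x),
\qquad
\mathcal{L}_{\mathrm{MLM}} \;=\; -\,\frac{\sum_t m_t \log p(y_t \mid x)}{\sum_t m_t},
\]

where z ranges over the retrieved blocks plus the null block, p(z | x) are the block_probs, and m_t is the MLM loss_mask. Note the code approximates this mixture by weighting the per-block logits with block_probs before the cross entropy rather than mixing probabilities.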
def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
"""log P(y | z, x) - log P(y | null, x)"""
# [batch x seq_len x vocab_size]
lm_logits = lm_logits[:, :, :labels.shape[1], :]
#non_null_block_probs = block_probs[:, :-1]
#non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
# non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :])
null_block_lm_logits = lm_logits[:, -1, :, :]
null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
labels.contiguous())
null_block_loss = torch.sum(
null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
retrieved_block_losses = []
for block_num in range(lm_logits.shape[1] - 1):
retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
labels.contiguous())
#retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
retrieved_block_loss = torch.sum(
retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
retrieved_block_losses.append(retrieved_block_loss)
avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)
max_retrieval_utility = null_block_loss - min(retrieved_block_losses)
top_retrieval_utility = null_block_loss - retrieved_block_losses[0]
avg_retrieval_utility = null_block_loss - avg_retrieved_block_loss
return max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility
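For reference (not part of the original diff), the quantity computed by get_retrieval_utility above, in LaTeX:

\[
\mathrm{RU}(z \mid x, y) \;=\; \log p(y \mid z, x) \;-\; \log p(y \mid \varnothing, x),
\]

where \varnothing is the null block appended by the retriever and each log-likelihood is the negated masked-token cross entropy. The function reports the maximum, the top-ranked, and the average of RU over the k retrieved blocks.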
def qa_forward_step(data_iterator, model):
timers = get_timers()
# this dataset interface needs to be implemented
timers('batch generator').start()
question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
timers('batch generator').stop()
batch_span_logits, batch_loss_masks, block_probs = model(question_tokens, question_attention_mask,
answer_tokens, answer_token_lengths)
# [batch_size x k x num_spans]
block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
batch_span_probs = F.softmax(batch_span_logits, dim=2)
reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)
qa_span_loss_ = -torch.log(reduced_block_span_probs)
qa_span_loss = torch.sum(
qa_span_loss_.view(-1) * batch_loss_masks
)
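For reference (not part of the original diff), the span-selection loss that qa_forward_step sketches, in LaTeX:

\[
p(s \mid x) \;=\; \sum_{z \in \mathrm{top}\text{-}k(x)} p(s \mid z, x)\, p(z \mid x),
\qquad
\mathcal{L}_{\mathrm{QA}} \;=\; -\sum_{s} m_s \log p(s \mid x),
\]

where s ranges over the candidate spans in start_ends, p(s | z, x) is the softmax over span logits, and m_s is the entry of batch_loss_masks marking spans whose tokens exactly match the answer.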
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid and test datasets."""
args = get_args()
print_rank_0('> building train, validation, and test datasets '
'for BERT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
max_seq_length=args.seq_length,
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
dataset_type='realm')
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
initializer_func=initialize_and_run_async_megatron)