Commit 05620ee4 authored by Neel Kant

Merge branch 'ict-merge' into 'master'

ICT code

See merge request ADLR/megatron-lm!90
parents c20f4d48 5247f24c
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ICTBertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
num_batches = 0


def general_model_provider(only_query_model=False, only_block_model=False):
    """Build the model."""
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert args.model_parallel_size == 1, \
        "Model parallel size > 1 not supported for ICT"
    print_rank_0('building ICTBertModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)

    return model


def model_provider():
    """Build the full ICT model (both query and block towers)."""
    return general_model_provider(False, False)


def get_group_world_size_rank():
    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size


class AllgatherFromDataParallelRegion(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()
        assert grad_output.shape[0] % world_size == 0

        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output
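

# Illustrative sketch (not part of the original change): how the autograd function
# above behaves at the boundary of the data-parallel region. With a data-parallel
# group of size W, forward() turns a local [local_bs, dim] tensor into the
# concatenated [W * local_bs, dim] tensor, while backward() hands this rank only
# its own [local_bs, dim] slice of the incoming gradient. The helper below is a
# hypothetical name used purely for illustration and assumes torch.distributed
# and the megatron mpu groups are already initialized.
def _allgather_roundtrip_sketch(local_embeddings):
    gathered = AllgatherFromDataParallelRegion.apply(local_embeddings)
    # gathered.shape[0] == world_size * local_embeddings.shape[0]; gradients
    # flowing back into `gathered` are split so that this rank only receives
    # the rows it originally contributed.
    return gathered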


def get_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask,
                                       block_tokens, block_pad_mask)

    local_batch_size = query_logits.shape[0]
    # recall we assert that model_parallel_size == 1
    global_batch_size = dist.get_world_size() * local_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(
        softmaxed, k=softmaxed.shape[1], sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor(
            [sum([int(i in sorted_indices[i, :k])
                  for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

    # The correct block for query i is block i, so the labels for the
    # softmax cross-entropy are simply 0..global_batch_size - 1.
    retrieval_loss = torch.nn.CrossEntropyLoss()(
        retrieval_scores, torch.arange(global_batch_size).long().cuda())
    reduced_losses = reduce_losses([retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in
                     zip(args.report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
    return retrieval_loss, stats_dict
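

# Illustrative sketch (not part of the original change): the in-batch retrieval
# objective computed above, on a toy global batch. Query i's positive example is
# block i, so the labels are just 0..N-1 and every other block in the batch acts
# as a negative. `_toy_retrieval_loss` is a hypothetical helper for illustration
# only (CPU tensors, no all-gather).
def _toy_retrieval_loss(num_pairs=4, hidden=8):
    queries = torch.randn(num_pairs, hidden)   # stand-in query embeddings
    blocks = torch.randn(num_pairs, hidden)    # stand-in block embeddings
    scores = queries.matmul(blocks.t())        # [N, N] similarity matrix
    labels = torch.arange(num_pairs)           # correct block for query i is i
    return torch.nn.CrossEntropyLoss()(scores, labels)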


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
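

# Hypothetical invocation sketch (not part of the original change), based only on
# the arguments referenced in this file; the exact flag names, values, and the
# remaining standard Megatron BERT pretraining arguments are assumptions:
#
#   python pretrain_ict.py \
#       --data-path <ict-dataset-prefix> \
#       --ict-head-size 128 \
#       --report-topk-accuracies 1 5 10 \
#       <usual BERT pretraining flags: model size, batch size, lr, ...>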

@@ -21,8 +21,8 @@ import time
import torch
from megatron import get_args
from megatron import mpu
from megatron import print_rank_0
from megatron import mpu
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch

@@ -18,9 +18,9 @@
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results

@@ -16,8 +16,8 @@
"""GLUE finetuning/evaluation."""
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune

@@ -16,8 +16,8 @@
"""Race."""
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune

@@ -22,8 +22,8 @@ import numpy as np
import torch
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import get_tokenizer
from .detokenizer import get_detokenizer

@@ -20,9 +20,9 @@ import math
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.model import GPT2Model
from megatron.training import get_model

@@ -21,8 +21,8 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import get_args
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.model import GPT2Model