Commit 8ec5d678 authored by hepj987's avatar hepj987
Browse files

GPT2 base on megatron-deepspeed

parents
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE dataset."""
from abc import ABC
from abc import abstractmethod
from torch.utils.data import Dataset
from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_text
class GLUEAbstractDataset(ABC, Dataset):
"""GLUE base dataset class."""
def __init__(self, task_name, dataset_name, datapaths,
tokenizer, max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
# Process the files.
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(self.process_samples_from_single_path(datapath))
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
ids, types, paddings = build_tokens_types_paddings_from_text(
raw_sample['text_a'], raw_sample['text_b'],
self.tokenizer, self.max_seq_length)
sample = build_sample(ids, types, paddings,
raw_sample['label'], raw_sample['uid'])
return sample
@abstractmethod
def process_samples_from_single_path(self, datapath):
"""Abstract method that takes a single path / filename and
returns a list of dataset samples, each sample being a dict of
{'text_a': string, 'text_b': string, 'label': int, 'uid': int}
"""
pass
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GLUE finetuning/evaluation."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
def glue_classification(num_classes, Dataset,
name_from_datapath_func):
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = Dataset('training', args.train_data,
tokenizer, args.seq_length)
valid_dataset = Dataset('validation', args.valid_data,
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
args = get_args()
print_rank_0('building classification model for {} ...'.format(
args.task))
model = Classification(num_classes=num_classes, num_tokentypes=2,
pre_process=pre_process, post_process=post_process)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
def single_dataset_provider(datapath):
args = get_args()
tokenizer = get_tokenizer()
name = name_from_datapath_func(datapath)
return Dataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
def main():
args = get_args()
if args.task == 'MNLI':
num_classes = 3
from tasks.glue.mnli import MNLIDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('MNLI')[-1].strip(
'.tsv').strip('/').replace('_', '-')
elif args.task == 'QQP':
num_classes = 2
from tasks.glue.qqp import QQPDataset as Dataset
def name_from_datapath(datapath):
return datapath.split('QQP')[-1].strip(
'.tsv').strip('/').replace('_', '-')
else:
raise NotImplementedError('GLUE task {} is not implemented.'.format(
args.task))
glue_classification(num_classes, Dataset, name_from_datapath)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNLI dataset."""
from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
LABELS = {'contradiction': 0, 'entailment': 1, 'neutral': 2}
class MNLIDataset(GLUEAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length,
test_label='contradiction'):
self.test_label = test_label
super().__init__('MNLI', name, datapaths,
tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
first = True
is_test = False
with open(filename, 'r') as f:
for line in f:
row = line.strip().split('\t')
if first:
first = False
if len(row) == 10:
is_test = True
print_rank_0(
' reading {}, {} and {} columns and setting '
'labels to {}'.format(
row[0].strip(), row[8].strip(),
row[9].strip(), self.test_label))
else:
print_rank_0(' reading {} , {}, {}, and {} columns '
'...'.format(
row[0].strip(), row[8].strip(),
row[9].strip(), row[-1].strip()))
continue
text_a = clean_text(row[8].strip())
text_b = clean_text(row[9].strip())
unique_id = int(row[0].strip())
label = row[-1].strip()
if is_test:
label = self.test_label
assert len(text_a) > 0
assert len(text_b) > 0
assert label in LABELS
assert unique_id >= 0
sample = {'text_a': text_a,
'text_b': text_b,
'label': LABELS[label],
'uid': unique_id}
total += 1
samples.append(sample)
if total % 50000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""QQP dataset."""
from megatron import print_rank_0
from tasks.data_utils import clean_text
from .data import GLUEAbstractDataset
LABELS = [0, 1]
class QQPDataset(GLUEAbstractDataset):
def __init__(self, name, datapaths, tokenizer, max_seq_length,
test_label=0):
self.test_label = test_label
super().__init__('QQP', name, datapaths,
tokenizer, max_seq_length)
def process_samples_from_single_path(self, filename):
""""Implement abstract method."""
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
first = True
is_test = False
with open(filename, 'r') as f:
for line in f:
row = line.strip().split('\t')
if first:
first = False
if len(row) == 3:
is_test = True
print_rank_0(' reading {}, {}, and {} columns and '
'setting labels to {}'.format(
row[0].strip(), row[1].strip(),
row[2].strip(), self.test_label))
else:
assert len(row) == 6
print_rank_0(' reading {}, {}, {}, and {} columns'
' ...'.format(
row[0].strip(), row[3].strip(),
row[4].strip(), row[5].strip()))
continue
if is_test:
assert len(row) == 3, 'expected length 3: {}'.format(row)
uid = int(row[0].strip())
text_a = clean_text(row[1].strip())
text_b = clean_text(row[2].strip())
label = self.test_label
assert len(text_a) > 0
assert len(text_b) > 0
else:
if len(row) == 6:
uid = int(row[0].strip())
text_a = clean_text(row[3].strip())
text_b = clean_text(row[4].strip())
label = int(row[5].strip())
else:
print_rank_0('***WARNING*** index error, '
'skipping: {}'.format(row))
continue
if len(text_a) == 0:
print_rank_0('***WARNING*** zero length a, '
'skipping: {}'.format(row))
continue
if len(text_b) == 0:
print_rank_0('***WARNING*** zero length b, '
'skipping: {}'.format(row))
continue
assert label in LABELS
assert uid >= 0
sample = {'uid': uid,
'text_a': text_a,
'text_b': text_b,
'label': label}
total += 1
samples.append(sample)
if total % 50000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument('--epochs', type=int, default=None,
help='Number of finetunning epochs. Zero results in '
'evaluation only.')
group.add_argument('--pretrained-checkpoint', type=str, default=None,
help='Pretrained checkpoint used for finetunning.')
group.add_argument('--keep-last', action='store_true',
help='Keep the last batch (maybe incomplete) in'
'the data loader')
group.add_argument('--train-data', nargs='+', default=None,
help='Whitespace separated paths or corpora names '
'for training.')
group.add_argument('--valid-data', nargs='*', default=None,
help='path(s) to the validation data.')
group.add_argument('--overlapping-eval', type=int, default=32,
help='Sliding window for overlapping evaluation.')
group.add_argument('--strict-lambada', action='store_true',
help='Use more difficult formulation of lambada.')
# Retriever args
group.add_argument('--qa-data-dev', type=str, default=None,
help='Path to the QA dataset dev file.')
group.add_argument('--qa-data-test', type=str, default=None,
help='Path to the QA dataset test file.')
# Faiss arguments for retriever
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--faiss-match', type=str, default='string', \
choices=['regex', 'string'], help="Answer matching '\
'logic type")
group.add_argument('--faiss-topk-retrievals', type=int, default=100,
help='Number of blocks to use as top-k during retrieval')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'RACE':
from race.finetune import main
elif args.task in ['MNLI', 'QQP']:
from glue.finetune import main
elif args.task in ['LAMBADA', 'WIKITEXT103']:
from zeroshot_gpt.evaluate import main
elif args.task in ['ICT-ZEROSHOT-NQ']:
from orqa.evaluate_orqa import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
from megatron import get_args
from tasks.orqa.evaluate_utils import ORQAEvaluator
def main():
"""
Main program
"""
args = get_args()
# Set up the model and evaluator
evaluator = ORQAEvaluator()
# Run evaluation
if args.qa_data_dev is not None:
evaluator.evaluate(args.qa_data_dev, "DEV")
if args.qa_data_test is not None:
evaluator.evaluate(args.qa_data_test, "TEST")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from tasks.orqa.natural_questions.nq import get_nq_dataset
from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
from tasks.orqa.natural_questions.nq import process_nq_batch
from tasks.orqa.natural_questions.qa_utils import calculate_matches
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model
class ORQAEvaluator(object):
def __init__(self):
args = get_args()
self.embedding_size = args.hidden_size
self.faiss_use_gpu = args.faiss_use_gpu
self.evidence_embedder_obj = None
self.evidence_dataset = None
self.mips_index = None
self.eval_dataset = None
# Get Evidence (Wikipedia) dataset
self.get_evidence_dataset()
# Load query encoder checkpoint
only_query_model = True
if args.biencoder_shared_query_context_model:
only_query_model = False
model = get_model(lambda: biencoder_model_provider(only_query_model=\
only_query_model, biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model))
self.model = load_biencoder_checkpoint(model,
only_query_model=only_query_model)
assert len(self.model) == 1
self.model[0].eval()
# Load faiss indexer
self.faiss_wrapper()
def get_evidence_embedding(self):
# This will load the embedding from the embedding path
self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)
def get_evidence_dataset(self):
self.evidence_dataset = get_open_retrieval_wiki_dataset()
def faiss_wrapper(self):
# Initialize FAISS wrapper on local rank = 0 as the evidence embeddings
# is distributed over all the GPUs in a node and FAISS is not
# thread-safe
args = get_args()
if args.local_rank == 0:
# Get evidence embeddings computed using context encoder
self.get_evidence_embedding()
assert self.evidence_embedder_obj is not None
self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
embed_data=self.evidence_embedder_obj,
use_gpu=self.faiss_use_gpu)
# Wait for the FAISS index to be initialized in all the nodes
torch.distributed.barrier()
def generate_query_vectors(self, qa_data, split):
self.eval_dataset = get_nq_dataset(qa_data, split)
dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)
query_vectors = []
reference_list = []
for batch in dataloader:
# batch also has query_tokens and query_pad_data
query_tokens, query_mask, query_types, \
query_len, reference = process_nq_batch(batch)
assert len(self.model) == 1
unwrapped_model = self.model[0]
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
with torch.no_grad():
query_logits = unwrapped_model.embed_text(
unwrapped_model.query_model, query_tokens,
query_mask, query_types)
reference_list.extend(reference)
query_vectors.extend(query_logits.split(1, dim=0))
if len(query_vectors) % 100 == 0:
print_rank_0('Encoded queries {}'.format(len(query_vectors)))
query_tensor = torch.cat(query_vectors, dim=0)
print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))
assert query_tensor.size(0) == len(self.eval_dataset)
return query_tensor, reference_list
def evaluate(self, qa_data, split):
args = get_args()
query_tensor, reference_list = self.generate_query_vectors(qa_data, \
split)
local_rank = args.local_rank
rank = torch.distributed.get_rank()
device_count = torch.cuda.device_count()
num_nodes = torch.distributed.get_world_size() // device_count
node_id = rank // device_count
for node in range(num_nodes):
start_rank = node * device_count
end_rank = (node + 1) * device_count
ranks_list = list(range(start_rank, end_rank))
node_group = torch.distributed.new_group(ranks=ranks_list)
if node_id == node:
device_start_rank = start_rank
group = node_group
input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
torch.distributed.all_gather(tensor_list, query_tensor, group=group)
if local_rank == 0 and self.mips_index is not None:
all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()
distance, topkindex = self.mips_index.search_mips_index(
all_query_tensor, top_k=args.faiss_topk_retrievals,
reconstruct=False)
distance = torch.from_numpy(distance).cuda()
topkindex = torch.LongTensor(topkindex).cuda()
if local_rank != 0:
distance = torch.empty(device_count * len(query_tensor), \
args.faiss_topk_retrievals, dtype=torch.float32).cuda()
topkindex = torch.empty(device_count * len(query_tensor), \
args.faiss_topk_retrievals, dtype=torch.int64).cuda()
torch.distributed.broadcast(distance, src=device_start_rank, \
group=group)
torch.distributed.broadcast(topkindex, src=device_start_rank, \
group=group)
distance = torch.split(distance, len(query_tensor), dim=0)\
[local_rank]
topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
[local_rank]
top_ids_and_scores = []
for darray, topkarray in zip(distance, topkindex):
top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))
passages = self.evidence_dataset.id2text
match_stats = calculate_matches(passages,
reference_list,
top_ids_and_scores,
workers_num=args.num_workers,
match_type=args.faiss_match)
top_k_hits = match_stats.top_k_hits
print_rank_0("{} SET RESULTS".format(split))
print_rank_0("topk-{} documents hits {}".format(
args.faiss_topk_retrievals, top_k_hits))
top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))
for i in args.retriever_report_topk_accuracies:
print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))
return
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_nq_dataset(qa_data, split):
args = get_args()
tokenizer = get_tokenizer()
dataset = NQDataset('Google NQ {} Split'.format(split),
'Google Natural Questions',
qa_data,
tokenizer,
args.retriever_seq_length)
return dataset
def process_nq_batch(batch):
query_tokens = batch['token_ids'].long().cuda()
query_mask = (batch['token_mask'] < 0.5).cuda()
query_types = batch['token_types'].long().cuda()
query_len = batch['seq_len'].long().cuda()
reference = batch['reference']
return query_tokens, query_mask, query_types, query_len, reference
class CustomDataLoader(DataLoader):
def __init__(self, dataset, eval=False, **kwargs):
if kwargs.get('collate_fn', None) is None:
kwargs['collate_fn'] = self._collate_fn
self.eval = eval
super().__init__(dataset, **kwargs)
def _collate_fn(self, batch_data):
# generate batch
batch_size = len(batch_data)
tensorized = OrderedDict()
for d in batch_data:
for k, v in d.items():
tensorized.setdefault(k, []).append(v)
assert len(tensorized) == 5
tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
return tensorized
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size.
NOTE: This dataloader is not distributed !!!
"""
args = get_args()
if micro_batch_size is None:
micro_batch_size = args.micro_batch_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
# importantly, drop_last must be False to get all the data.
batch_sampler = BatchSampler(sampler,
batch_size=micro_batch_size,
drop_last=False)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = CustomDataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
return data_loader
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
"""Build token types and paddings, trim if needed, and pad if needed."""
src_text_ids = tokenizer.tokenize(src_text)
return build_tokens_types_paddings_from_ids(src_text_ids,
max_seq_length,
tokenizer.cls,
tokenizer.sep,
tokenizer.pad)
def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
sep_id, pad_id):
"""
Build token types and paddings, trim if needed, and pad if needed.
TODO: Design modular interface to reuse this function. This is getting
repeated multiple times in different tasks
"""
enc_ids = []
tokentypes_enc = []
# [CLS].
enc_ids.append(cls_id)
tokentypes_enc.append(0)
# A.
len_src = len(src_ids)
enc_ids.extend(src_ids)
tokentypes_enc.extend([0] * len_src)
# Cap the size.
if len(enc_ids) > max_seq_length - 1:
enc_ids = enc_ids[0: max_seq_length - 1]
tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]
# [SEP].
enc_ids.append(sep_id)
tokentypes_enc.append(0)
num_tokens_enc = len(enc_ids)
# Padding.
padding_length = max_seq_length - len(enc_ids)
if padding_length > 0:
enc_ids.extend([pad_id] * padding_length)
tokentypes_enc.extend([pad_id] * padding_length)
return enc_ids, tokentypes_enc, num_tokens_enc
def build_sample(token_ids, token_types, num_tokens, reference):
"""
Convert to numpy and return a sample consumed by the
batch producer.
"""
token_ids = np.array(token_ids, dtype=np.int64)
token_types = np.array(token_types, dtype=np.int64)
token_mask = make_attention_mask(token_ids, token_ids)
sample = ({
'token_ids': token_ids,
'token_mask': token_mask,
'token_types': token_types,
'seq_len': num_tokens,
'reference': reference
})
return sample
class NQDataset(ABC, Dataset):
"""
Open Retrieval Question Answering evaluation using Google NQ dataset.
"""
def __init__(self, task_name, dataset_name, datapath,
tokenizer, max_seq_length):
# Store inputs.
self.task_name = task_name
self.dataset_name = dataset_name
self.tokenizer = tokenizer
self.max_seq_length = max_seq_length
print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
self.dataset_name))
print_rank_0(datapath)
self.samples = self.process_samples_from_single_path(datapath)
print_rank_0(' >> total number of samples: {}'.format(\
len(self.samples)))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
raw_sample = self.samples[idx]
ques_tokens, tokentypes_enc, num_tokens_ques = \
build_tokens_types_paddings_from_text(raw_sample['question'],
self.tokenizer, self.max_seq_length)
sample = build_sample(ques_tokens,
tokentypes_enc,
num_tokens_ques,
raw_sample['answers'])
return sample
@staticmethod
def process_samples_from_single_path(filename):
print_rank_0(' > Processing {} ...'.format(filename))
samples = []
total = 0
with open(filename, 'r') as ifile:
reader = csv.reader(ifile, delimiter='\t')
for row in reader:
question = row[0]
answers = eval(row[1])
sample = {'question': question, 'answers': answers}
total += 1
samples.append(sample)
if total % 1000 == 0:
print_rank_0(' > processed {} so far ...'.format(total))
print_rank_0(' >> processed {} samples.'.format(len(samples)))
return samples
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Set of utilities for Q&A results validation tasks - Retriver passage
validation and Reader predicted answer validation
"""
import collections
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import regex as re
from megatron import logging
from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
logger = logging.get_logger(__name__)
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
'questions_doc_hits'])
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
answers: List[List[str]], closest_docs: List[Tuple[List[object],
List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
"""
Evaluates answers presence in the set of documents. This function is
supposed to be used with a large collection of documents and results.
It internally forks multiple sub-processes for evaluation and then
merges results
:param all_docs: dictionary of the entire documents database.
doc_id -> (doc_text, title)
:param answers: list of answers's list. One list per question
:param closest_docs: document ids of the top results along with their
scores
:param workers_num: amount of parallel threads to process data
:param match_type: type of answer matching. Refer to has_answer code for
available options
:return: matching information tuple.
top_k_hits - a list where the index is the amount of top documents retrieved
and the value is the total amount of valid matches across an entire
dataset.
questions_doc_hits - more detailed info with answer matches for every
question and every retrieved document
"""
global dpr_all_documents
dpr_all_documents = all_docs
tok_opts = {}
tokenizer = SimpleTokenizer(**tok_opts)
processes = ProcessPool(
processes=workers_num,
)
logger.info('Matching answers in top docs...')
get_score_partial = partial(check_answer, match_type=match_type,
tokenizer=tokenizer)
questions_answers_docs = zip(answers, closest_docs)
scores = processes.map(get_score_partial, questions_answers_docs)
logger.info('Per question validation results len=%d', len(scores))
n_docs = len(closest_docs[0][0])
top_k_hits = [0] * n_docs
for question_hits in scores:
best_hit = next((i for i, x in enumerate(question_hits) if x), None)
if best_hit is not None:
top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
"""
Search through all the top docs to see if they have any of the answers.
"""
answers, (doc_ids, doc_scores) = questions_answers_docs
global dpr_all_documents
hits = []
for i, doc_id in enumerate(doc_ids):
doc = dpr_all_documents[doc_id]
text = doc[0]
answer_found = False
if text is None: # cannot find the document for some reason
logger.warning("no doc in db")
hits.append(False)
continue
if has_answer(answers, text, tokenizer, match_type):
answer_found = True
hits.append(answer_found)
return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
"""
Check if a document contains an answer string.
If `match_type` is string, token matching is done between the text
and answer.
If `match_type` is regex, we search the whole text with the regex.
"""
text = _normalize(text)
if match_type == 'string':
# Answer is a list of possible strings
text = tokenizer.tokenize(text).words(uncased=True)
for single_answer in answers:
single_answer = _normalize(single_answer)
single_answer = tokenizer.tokenize(single_answer)
single_answer = single_answer.words(uncased=True)
for i in range(0, len(text) - len(single_answer) + 1):
if single_answer == text[i: i + len(single_answer)]:
return True
elif match_type == 'regex':
# Answer is a regex
for single_answer in answers:
single_answer = _normalize(single_answer)
if regex_match(text, single_answer):
return True
return False
def regex_match(text, pattern):
"""Test if a regex pattern is contained within a text."""
try:
pattern = re.compile(
pattern,
flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
)
except BaseException:
return False
return pattern.search(text) is not None
# function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
return _normalize_answer(prediction) == _normalize_answer(ground_truth)
def _normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _normalize(text):
return unicodedata.normalize('NFD', text)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency
"""
import copy
import regex
import spacy
from megatron import logging
logger = logging.get_logger(__name__)
class Tokens(object):
"""A class to represent a list of tokenized text."""
TEXT = 0
TEXT_WS = 1
SPAN = 2
POS = 3
LEMMA = 4
NER = 5
def __init__(self, data, annotators, opts=None):
self.data = data
self.annotators = annotators
self.opts = opts or {}
def __len__(self):
"""The number of tokens."""
return len(self.data)
def slice(self, i=None, j=None):
"""Return a view of the list of tokens from [i, j)."""
new_tokens = copy.copy(self)
new_tokens.data = self.data[i: j]
return new_tokens
def untokenize(self):
"""Returns the original text (with whitespace reinserted)."""
return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
def words(self, uncased=False):
"""Returns a list of the text of each token
Args:
uncased: lower cases text
"""
if uncased:
return [t[self.TEXT].lower() for t in self.data]
else:
return [t[self.TEXT] for t in self.data]
def offsets(self):
"""Returns a list of [start, end) character offsets of each token."""
return [t[self.SPAN] for t in self.data]
def pos(self):
"""Returns a list of part-of-speech tags of each token.
Returns None if this annotation was not included.
"""
if 'pos' not in self.annotators:
return None
return [t[self.POS] for t in self.data]
def lemmas(self):
"""Returns a list of the lemmatized text of each token.
Returns None if this annotation was not included.
"""
if 'lemma' not in self.annotators:
return None
return [t[self.LEMMA] for t in self.data]
def entities(self):
"""Returns a list of named-entity-recognition tags of each token.
Returns None if this annotation was not included.
"""
if 'ner' not in self.annotators:
return None
return [t[self.NER] for t in self.data]
def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
"""Returns a list of all ngrams from length 1 to n.
Args:
n: upper limit of ngram length
uncased: lower cases text
filter_fn: user function that takes in an ngram list and returns
True or False to keep or not keep the ngram
as_string: return the ngram as a string vs list
"""
def _skip(gram):
if not filter_fn:
return False
return filter_fn(gram)
words = self.words(uncased)
ngrams = [(s, e + 1)
for s in range(len(words))
for e in range(s, min(s + n, len(words)))
if not _skip(words[s:e + 1])]
# Concatenate into strings
if as_strings:
ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
return ngrams
def entity_groups(self):
"""Group consecutive entity tokens with the same NER tag."""
entities = self.entities()
if not entities:
return None
non_ent = self.opts.get('non_ent', 'O')
groups = []
idx = 0
while idx < len(entities):
ner_tag = entities[idx]
# Check for entity tag
if ner_tag != non_ent:
# Chomp the sequence
start = idx
while (idx < len(entities) and entities[idx] == ner_tag):
idx += 1
groups.append((self.slice(start, idx).untokenize(), ner_tag))
else:
idx += 1
return groups
class Tokenizer(object):
"""Base tokenizer class.
Tokenizers implement tokenize, which should return a Tokens class.
"""
def tokenize(self, text):
raise NotImplementedError
def shutdown(self):
pass
def __del__(self):
self.shutdown()
class SimpleTokenizer(Tokenizer):
ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
NON_WS = r'[^\p{Z}\p{C}]'
def __init__(self, **kwargs):
"""
Args:
annotators: None or empty set (only tokenizes).
"""
self._regexp = regex.compile(
'(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
)
if len(kwargs.get('annotators', {})) > 0:
logger.warning('%s only tokenizes! Skipping annotators: %s' %
(type(self).__name__, kwargs.get('annotators')))
self.annotators = set()
def tokenize(self, text):
data = []
matches = [m for m in self._regexp.finditer(text)]
for i in range(len(matches)):
# Get text
token = matches[i].group()
# Get whitespace
span = matches[i].span()
start_ws = span[0]
if i + 1 < len(matches):
end_ws = matches[i + 1].span()[0]
else:
end_ws = span[1]
# Format data
data.append((
token,
text[start_ws: end_ws],
span,
))
return Tokens(data, self.annotators)
class SpacyTokenizer(Tokenizer):
def __init__(self, **kwargs):
"""
Args:
annotators: set that can include pos, lemma, and ner.
model: spaCy model to use (either path, or keyword like 'en').
"""
model = kwargs.get('model', 'en')
self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
nlp_kwargs = {'parser': False}
if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
nlp_kwargs['tagger'] = False
if 'ner' not in self.annotators:
nlp_kwargs['entity'] = False
self.nlp = spacy.load(model, **nlp_kwargs)
def tokenize(self, text):
# We don't treat new lines as tokens.
clean_text = text.replace('\n', ' ')
tokens = self.nlp.tokenizer(clean_text)
if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
self.nlp.tagger(tokens)
if 'ner' in self.annotators:
self.nlp.entity(tokens)
data = []
for i in range(len(tokens)):
# Get whitespace
start_ws = tokens[i].idx
if i + 1 < len(tokens):
end_ws = tokens[i + 1].idx
else:
end_ws = tokens[i].idx + len(tokens[i].text)
data.append((
tokens[i].text,
text[start_ws: end_ws],
(tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
tokens[i].tag_,
tokens[i].lemma_,
tokens[i].ent_type_,
))
# Set special option for non-entity tag: '' vs 'O' in spaCy
return Tokens(data, self.annotators, opts={'non_ent': ''})
import glob
import json
import os
import time
from torch.utils.data import Dataset
from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text
NUM_CHOICES = 4
MAX_QA_LENGTH = 128
class RaceDataset(Dataset):
def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
max_qa_length=MAX_QA_LENGTH):
self.dataset_name = dataset_name
print_rank_0(' > building RACE dataset for {}:'.format(
self.dataset_name))
string = ' > paths:'
for path in datapaths:
string += ' ' + path
print_rank_0(string)
self.samples = []
for datapath in datapaths:
self.samples.extend(process_single_datapath(datapath, tokenizer,
max_qa_length,
max_seq_length))
print_rank_0(' >> total number of samples: {}'.format(
len(self.samples)))
# This indicates that each "sample" has multiple samples that
# will collapse into batch dimension
self.sample_multiplier = NUM_CHOICES
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
"""Read in RACE files, combine, clean-up, tokenize, and convert to
samples."""
print_rank_0(' > working on {}'.format(datapath))
start_time = time.time()
# Get list of files.
filenames = glob.glob(os.path.join(datapath, '*.txt'))
samples = []
num_docs = 0
num_questions = 0
num_samples = 0
# Load all the files
for filename in filenames:
with open(filename, 'r') as f:
for line in f:
data = json.loads(line)
num_docs += 1
context = data["article"]
questions = data["questions"]
choices = data["options"]
answers = data["answers"]
# Check the length.
assert len(questions) == len(answers)
assert len(questions) == len(choices)
# Context: clean up and convert to ids.
context = clean_text(context)
context_ids = tokenizer.tokenize(context)
# Loop over questions.
for qi, question in enumerate(questions):
num_questions += 1
# Label.
label = ord(answers[qi]) - ord("A")
assert label >= 0
assert label < NUM_CHOICES
assert len(choices[qi]) == NUM_CHOICES
# For each question, build num-choices samples.
ids_list = []
types_list = []
paddings_list = []
for ci in range(NUM_CHOICES):
choice = choices[qi][ci]
# Merge with choice.
if "_" in question:
qa = question.replace("_", choice)
else:
qa = " ".join([question, choice])
# Clean QA.
qa = clean_text(qa)
# Tokenize.
qa_ids = tokenizer.tokenize(qa)
# Trim if needed.
if len(qa_ids) > max_qa_length:
qa_ids = qa_ids[0:max_qa_length]
# Build the sample.
ids, types, paddings \
= build_tokens_types_paddings_from_ids(
qa_ids, context_ids, max_seq_length,
tokenizer.cls, tokenizer.sep, tokenizer.pad)
ids_list.append(ids)
types_list.append(types)
paddings_list.append(paddings)
# Convert to numpy and add to samples
samples.append(build_sample(ids_list, types_list,
paddings_list, label,
num_samples))
num_samples += 1
elapsed_time = time.time() - start_time
print_rank_0(' > processed {} document, {} questions, and {} samples'
' in {:.2f} seconds'.format(num_docs, num_questions,
num_samples, elapsed_time))
return samples
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
def train_valid_datasets_provider():
"""Provide train and validation datasets."""
args = get_args()
tokenizer = get_tokenizer()
train_dataset = RaceDataset('training', args.train_data,
tokenizer, args.seq_length)
valid_dataset = RaceDataset('validation', args.valid_data,
tokenizer, args.seq_length)
return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building multichoice model for RACE ...')
model = MultipleChoice(num_tokentypes=2,
pre_process=pre_process,
post_process=post_process)
return model
def metrics_func_provider():
"""Privde metrics callback function."""
args = get_args()
tokenizer = get_tokenizer()
def single_dataset_provider(datapath):
name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
return RaceDataset(name, [datapath], tokenizer, args.seq_length)
return accuracy_func_provider(single_dataset_provider)
def main():
finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback_provider=metrics_func_provider)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
from megatron import get_args
from megatron import print_rank_0
from megatron.model.vit_model import VitModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
def classification():
def train_valid_datasets_provider():
"""Build train and validation dataset."""
args = get_args()
train_ds, valid_ds = build_train_valid_datasets(
data_path=args.data_path,
crop_size=args.img_dim,
)
return train_ds, valid_ds
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0("building classification model for ImageNet ...")
return VitModel(num_classes=args.num_classes, finetune=True)
"""Finetune/evaluate."""
finetune(
train_valid_datasets_provider,
model_provider,
end_of_epoch_callback_provider=accuracy_func_provider,
)
def main():
classification()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
import os
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import mpu
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args = get_args()
data_path = args.data_path
crop_size = args.img_dim
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path = os.path.join(data_path[0], "val")
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_val = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
)
dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)
dataloader = build_data_loader(
dataset,
args.micro_batch_size,
num_workers=args.num_workers,
drop_last=(mpu.get_data_parallel_world_size() > 1),
)
def metrics_func(model, epoch):
print_rank_0("calculating metrics ...")
correct, total = calculate_correct_answers(model, dataloader, epoch)
percent = float(correct) * 100.0 / float(total)
print_rank_0(
" >> |epoch: {}| overall: correct / total = {} / {} = "
"{:.4f} %".format(epoch, correct, total, percent)
)
return metrics_func
def calculate_correct_answers(model, dataloader, epoch):
"""Calculate correct over total answers"""
model.eval()
with torch.no_grad():
# For all the batches in the dataset.
total = 0
correct = 0
for _, batch in enumerate(dataloader):
# Run the model forward.
images, labels = process_batch(batch)
logits = model(images).contiguous().float()
# Add output predictions.
# Compute the correct answers.
predicted = torch.argmax(logits, dim=-1)
corrects = (predicted == labels).float()
# Add to the counters.
total += labels.size(0)
correct += corrects.sum().item()
model.train()
# Reduce.
unreduced = torch.cuda.LongTensor([correct, total])
torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
# Print on screen.
correct_ans = unreduced[0].item()
total_count = unreduced[1].item()
return correct_ans, total_count
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group
def process_batch(batch):
"""Process batch and produce inputs for the model."""
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
return images, labels
def _cross_entropy_forward_step(batch, model, input_tensor):
"""Simple forward step with cross-entropy loss."""
timers = get_timers()
assert input_tensor is None
# Get the batch.
timers("batch generator").start()
try:
batch_ = next(batch)
except BaseException:
batch_ = batch
images, labels = process_batch(batch_)
timers("batch generator").stop()
# Forward model.
logits = model(images).contiguous().float()
# Cross-entropy loss.
loss = F.cross_entropy(logits, labels)
# Reduce loss for logging.
average_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": average_loss[0]}
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True,
)
return data_loader
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0("building train and validation dataloaders ...")
# Training dataset.
train_dataloader = build_data_loader(
train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(
valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
return train_dataloader, valid_dataloader
def _train(
model,
optimizer,
lr_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
):
"""Train the model."""
args = get_args()
timers = get_timers()
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
losses_dict_sum = {}
# Starting epoch and iteration
start_epoch = args.iteration // args.train_iters_per_epoch
start_iteration = args.iteration % args.train_iters_per_epoch
iteration = args.iteration
# Memory reporting flag.
report_memory_flag = True
# For each remaining epoch
timers("interval-time").start()
for epoch in range(start_epoch, args.epochs):
print_rank_0("working on epoch {} ...".format(epoch + 1))
# Set the data loader epoch to shuffle the index iterator.
train_dataloader.sampler.set_epoch(args.seed + epoch)
# For all the batches in the dataset.
for iteration_, batch in enumerate(train_dataloader):
# Ignore the iterations before starting value
if iteration_ < start_iteration:
continue
# Set to zero so the next epoch does not skip any batches.
start_iteration = 0
# Train for one step.
losses_dict, skipped_iter = train_step(
forward_step, batch, model, optimizer, lr_scheduler
)
iteration += 1
# Logging.
report_memory_flag = training_log(
losses_dict,
losses_dict_sum,
optimizer.param_groups[0]["lr"],
iteration,
optimizer.get_loss_scale().item(),
report_memory_flag,
skipped_iter,
)
# Autoresume
if args.adlr_autoresume and (
iteration % args.adlr_autoresume_interval == 0
):
check_adlr_autoresume_termination(
iteration, model, optimizer, lr_scheduler
)
# Checkpointing
if (
args.save
and args.save_interval
and iteration % args.save_interval == 0
):
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0:
prefix = "iteration {}".format(iteration)
evaluate_and_print_results(
prefix,
forward_step,
valid_dataloader,
model,
iteration,
False,
)
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
# Callback at the end of each epoch.
if end_of_epoch_callback is not None:
end_of_epoch_callback(model, epoch)
def finetune(
train_valid_datasets_provider,
model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None,
):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
# Train and validation data loaders.
timers("train/valid/test dataset/dataloder").start()
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset
)
timers("train/valid/test dataset/dataloder").stop()
# Build calback function.
timers("callback function").start()
end_of_epoch_callback = None
if end_of_epoch_callback_provider is not None:
end_of_epoch_callback = end_of_epoch_callback_provider()
timers("callback function").stop()
# Build model, optimizer and learning rate scheduler.
timers("model and optimizer").start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
timers("model and optimizer").stop()
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
timers("pretrained checkpoint").start()
if args.iteration == 0 and args.pretrained_checkpoint is not None:
original_load = args.load
args.load = args.pretrained_checkpoint
_ = load_checkpoint(model, None, None, strict=False)
args.load = original_load
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
optimizer.reload_model_params()
timers("pretrained checkpoint").stop()
# Print setup timing.
print_rank_0("done with setups ...")
timers.log(
[
"train/valid/test dataset/dataloder",
"callback function",
"model and optimizer",
"pretrained checkpoint",
]
)
print_rank_0("training ...")
# Finetune the model.
if args.epochs > 0:
_train(
model,
optimizer,
lr_scheduler,
forward_step,
train_dataloader,
valid_dataloader,
end_of_epoch_callback,
)
# Or just evaluate.
else:
if end_of_epoch_callback is not None:
print_rank_0("evaluation only mode, setting epoch to -1")
end_of_epoch_callback(model, epoch=-1, output_predictions=True)
print_rank_0("done :-)")
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
import os
import sys
sys.path.append(
os.path.abspath(
os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir),
os.path.pardir,
)
)
)
from megatron import get_args
from megatron.initialize import initialize_megatron
from classification import main
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title="tasks")
group.add_argument(
"--epochs",
type=int,
default=None,
help="Number of finetunning epochs. Zero results in "
"evaluation only.",
)
group.add_argument(
"--pretrained-checkpoint",
type=str,
default=None,
help="Pretrained checkpoint used for finetunning.",
)
group.add_argument(
"--keep-last",
action="store_true",
help="Keep the last batch (maybe incomplete) in" "the data loader",
)
return parser
if __name__ == "__main__":
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
main()
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Zero-shot datasets."""
import json
import math
import numpy as np
import torch
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from .detokenizer import get_detokenizer
def build_dataset(task):
"""Helper function to select and build dataset."""
if task == 'LAMBADA':
return _build_lambada_dataset()
if task == 'WIKITEXT103':
return _build_wikitext103_dataset()
raise NotImplementedError('dataset for {} task is not '
'implemented.'.format(task))
class _LMDataset(torch.utils.data.Dataset):
def __init__(self, tokens, seq_len, pad_idx, num_original_tokens,
num_tokenized_tokens, overalapping_eval=None):
self.tokens = tokens
self.seq_len = seq_len
self.pad_idx = pad_idx
self.overalapping_eval = overalapping_eval
if self.overalapping_eval is None:
self.overalapping_eval = self.seq_len
self.overalapping_eval = max(1, self.overalapping_eval)
self.num_original_tokens = num_original_tokens
self.num_tokenized_tokens = num_tokenized_tokens
self.total_targets = len(self.tokens) - 1
# remove first sequence tokens
targets = max(self.total_targets - self.overalapping_eval, 0)
self.total_sequences = max(
math.ceil(targets / self.overalapping_eval) + 1, 1)
def __len__(self):
return self.total_sequences
def __getitem__(self, idx):
start_idx = idx * self.overalapping_eval
end_idx = start_idx + self.seq_len
tokens = self.tokens[start_idx:end_idx + 1]
num_tokens = len(tokens)
pad_mask = [1] * num_tokens
if num_tokens < self.seq_len + 1:
num_pad = (self.seq_len + 1 - num_tokens)
pad_mask += [0] * (num_pad)
tokens += [self.pad_idx] * num_pad
pad_mask = np.array(pad_mask[1:])
if self.overalapping_eval != self.seq_len and idx != 0:
pad_mask[:-self.overalapping_eval] *= 0
return {'text': np.array(tokens), 'pad_mask': pad_mask}
class _LambadaDataset(torch.utils.data.Dataset):
def __init__(self, path, pad_idx, tokenizer, seq_len, strict=False):
print_rank_0('> building lambada dataset from {} ...'.format(path))
self.seq_len = seq_len
self.pad_idx = pad_idx
self.tokenizer = tokenizer
self.strict = strict
self.tokens = []
self.labels = []
with open(path, 'r') as f:
for line in f.readlines():
text = json.loads(line)['text']
tokens, labels = self.get_tokens(text)
self.tokens.append(tokens)
self.labels.append(labels)
def get_tokens(self, text):
if not self.strict:
tokens = self.tokenizer.tokenize(text)
return tokens[:-1], [tokens[-1]]
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
last_token = self.tokenizer.tokenize(' ' + last_token)
return beginning_tokens, last_token
def __len__(self):
return len(self.tokens)
def __getitem__(self, idx):
tokens = self.tokens[idx]
num_tokens = len(tokens)
pad_mask = [0] * num_tokens
labels = self.labels[idx]
pad_mask += [1] * len(labels)
tokens = tokens + labels
num_tokens = len(tokens)
if num_tokens < self.seq_len + 1:
num_pad = (self.seq_len + 1 - num_tokens)
pad_mask += [0] * (num_pad)
tokens += [self.pad_idx] * num_pad
pad_mask = np.array(pad_mask[1:])
return {'text': np.array(tokens), 'pad_mask': pad_mask}
def _build_lambada_dataset():
"""Build lambada dataset."""
args = get_args()
tokenizer = get_tokenizer()
assert len(args.valid_data) == 1
val_dataset = _LambadaDataset(args.valid_data[0], tokenizer.eod, tokenizer,
args.seq_length, args.strict_lambada)
print_rank_0(' > found {} samples.'.format(len(val_dataset)))
return val_dataset
def _build_wikitext103_dataset():
""""""
args = get_args()
tokenizer = get_tokenizer()
assert len(args.valid_data) == 1
with open(args.valid_data[0], "rb") as reader:
entire_data = reader.read().decode('utf-8')
num_original_tokens = len(entire_data.strip().split(" "))
entire_data = get_detokenizer(args.valid_data[0])(entire_data)
tokenized_data = tokenizer.tokenize(entire_data)
num_tokenized_tokens = len(tokenized_data)
val_dataset = _LMDataset(tokenized_data, args.seq_length, tokenizer.eod,
num_original_tokens, num_tokenized_tokens,
args.overlapping_eval)
print_rank_0(' > number of original tokens: {}, number of detokenized '
'tokens: {}'.format(num_original_tokens, num_tokenized_tokens))
return val_dataset
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detokenization."""
import re
def ptb_detokenizer(string):
string = string.replace(" '", "'")
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" n't", "n't")
string = string.replace(" N ", "1 ")
string = string.replace("$ 1", "$1")
string = string.replace("# 1", "#1")
return string
def wikitext_detokenizer(string):
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" " + chr(176) + " ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
def lambada_detokenizer(string):
return string
_DETOKENIZERS = {
'ptb': ptb_detokenizer,
'wiki': wikitext_detokenizer,
'lambada': lambada_detokenizer,
}
def get_detokenizer(path):
for key in _DETOKENIZERS.keys():
if key in path:
return _DETOKENIZERS[key]
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT zero-shot evaluation."""
import math
import torch
from megatron import get_args
from megatron import print_rank_0, is_last_rank
from megatron import get_tokenizer
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.model.gpt_model import GPTModel
from megatron.training import get_model
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
from tasks.finetune_utils import build_data_loader
from .datasets import build_dataset
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model.distributed import DistributedDataParallel as LocalDDP
from megatron.model.module import Float16Module
def get_model_provider(eval_metric):
"""Based on evaluation metric set the parallel-output flag and
return the model provider."""
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
if eval_metric == 'loss':
parallel_output = True
elif eval_metric == 'accuracy':
parallel_output = False
else:
raise NotImplementedError('output type for {} evaluation metric '
'is not supported.'.format(eval_metric))
print_rank_0('building GPT model ...')
model = GPTModel(num_tokentypes=0, parallel_output=parallel_output,
pre_process=pre_process, post_process=post_process)
return model
return model_provider
def process_batch(batch):
"""Process batch and produce inputs for the model."""
args = get_args()
tokenizer = get_tokenizer()
loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
tokens_ = batch['text'].long().cuda().contiguous()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
# Get the masks and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss,
prefix_indices=None,
loss_on_targets_only=args.loss_on_targets_only
)
return tokens, labels, attention_mask, position_ids, loss_mask
def forward_step(batch, model, eval_metric):
"""Forward step."""
# Get the batch.
tokens, labels, attention_mask, position_ids, loss_mask = process_batch(
batch)
# Tell the model what our actual batch size will be
args = get_args()
args.micro_batch_size = len(labels)
input_tensor = recv_forward()
# Forward pass through the model.
unwrapped_model = unwrap_model(
model, (torchDDP, LocalDDP, Float16Module))
unwrapped_model.set_input_tensor(input_tensor)
output = model(tokens, position_ids, attention_mask)
send_forward(output)
if mpu.is_pipeline_last_stage():
# For loss, return the unreduced loss.
if eval_metric == 'loss':
losses = mpu.vocab_parallel_cross_entropy(
output.contiguous().float(), labels.contiguous())
loss = torch.sum(
losses.view(-1) * loss_mask.contiguous().view(-1).float())
return loss
# For accuracy, return the number of correctly predicted samples.
if eval_metric == 'accuracy':
outputs = torch.argmax(output, -1)
correct = (outputs == labels).float()
correct[(1 - loss_mask).bool()] = 1
correct = correct.prod(-1)
return correct.sum()
raise NotImplementedError('forward method for evaluation metric {} '
'is not implemented.'.format(eval_metric))
return None
def evaluate(data_loader, model, eval_metric):
"""Evaluation."""
args = get_args()
# Turn on evaluation mode which disables dropout.
model.eval()
total_output = 0.0
with torch.no_grad():
# For all the batches in the dataset.
for iteration, batch in enumerate(data_loader):
if iteration % args.log_interval == 0:
print_rank_0('> working on iteration: {}'.format(iteration))
# Forward evaluation.
output = forward_step(batch, model, eval_metric)
# Reduce across processes.
if mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(output,
group=mpu.get_data_parallel_group())
total_output += output
return total_output
def evaluate_and_print_results(task, data_loader, model, eval_metric):
"""Evaluate and print results on screen."""
# Evaluate and get results.
output = evaluate(data_loader, model, eval_metric)
string = ' validation results on {} | '.format(task)
if is_last_rank():
if eval_metric == 'loss':
num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
num_original_tokens = data_loader.dataset.num_original_tokens
val_loss = output / (num_tokenized_tokens - 1)
ppl = math.exp(min(20, val_loss))
token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
string += 'avg loss: {:.4E} | '.format(val_loss)
string += 'ppl: {:.4E} | '.format(ppl)
string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
string += 'token ratio: {} |'.format(token_ratio)
elif eval_metric == 'accuracy':
num_examples = len(data_loader.dataset)
acc = output / num_examples
string += 'number correct: {:.4E} | '.format(output)
string += 'total examples: {:.4E} | '.format(num_examples)
string += 'avg accuracy: {:.4E}'.format(acc)
else:
raise NotImplementedError('evaluation method for {} metric is not '
'implemented yet.'.format(eval_metric))
length = len(string) + 1
print('-' * length)
print(string)
print('-' * length)
def main():
"""Main program."""
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
if args.task == 'LAMBADA':
eval_metric = 'accuracy'
elif args.task == 'WIKITEXT103':
eval_metric = 'loss'
else:
raise NotImplementedError('{} task is not implemented.'.format(
args.task))
# Set up model and load checkpoint.
model = get_model(get_model_provider(eval_metric))
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# Data stuff.
dataset = build_dataset(args.task)
dataloader = build_data_loader(dataset, args.micro_batch_size,
args.num_workers, drop_last=False)
# Run evaluation.
evaluate_and_print_results(args.task, dataloader, model, eval_metric)
print_rank_0('done :-)')
# Testing
This test suite heavily borrows from [HF Transformers](https://github.com/huggingface/transformers/), therefore you can refer to the its [testing docs](https://huggingface.co/transformers/testing.html) for in-depth details. In particular wrt writing new tests, as we have access a lot of helper classes and functions, so you can write tests very quickly and not need to reinvent the wheel.
The foundation is `pytest`, which allows you to write normal `pytest` tests, but we also use a lot of unit tests in particular via `TestCasePlus` which extends `unittest` and provides additional rich functionality.
## Running testing
```
make test
```
or:
```
pytest tests
```
Important: the first time you run this it can take some minutes to build all the Megatron cuda kernels and deepspeed kernels if you haven't pre-built the latter.
For various other options please see the doc mentioned at the very top.
You will want to have at least 1 gpu available, best 2 to run the tests.
## CI
The CI setup is documented [here](../.github/workflows/ci.md).
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment