Commit 0024a5c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'main' of https://github.com/NVIDIA/Megatron-LM

parents b004456b 3db2063b
Pipeline #229 failed with stages
in 0 seconds
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
from tasks.orqa.unsupervised.nq import get_nq_dataset
from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
from tasks.orqa.unsupervised.nq import process_nq_batch
from tasks.orqa.unsupervised.qa_utils import calculate_matches
class ORQAEvaluator(object):
    """Open-Retrieval QA evaluator.

    Loads the Wikipedia evidence dataset, a biencoder query encoder from a
    checkpoint, and a FAISS max-inner-product-search index over precomputed
    evidence embeddings, then scores retrieval on an NQ-style QA set.
    """
    def __init__(self):
        args = get_args()
        self.embedding_size = args.hidden_size
        self.faiss_use_gpu = args.faiss_use_gpu
        self.evidence_embedder_obj = None
        self.evidence_dataset = None
        self.mips_index = None
        self.eval_dataset = None

        # Get Evidence (Wikipedia) dataset
        self.get_evidence_dataset()

        # Load query encoder checkpoint
        only_query_model = True
        if args.biencoder_shared_query_context_model:
            only_query_model = False

        model = get_model(get_model_provider(only_query_model=only_query_model,
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
            only_query_model=only_query_model)

        # Megatron returns a list of model chunks; exactly one expected here.
        assert len(self.model) == 1
        self.model[0].eval()

        # Load faiss indexer
        self.faiss_wrapper()

    def get_evidence_embedding(self):
        # This will load the embedding from the embedding path
        self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True)

    def get_evidence_dataset(self):
        # Wikipedia passages; also provides id2text used in evaluate().
        self.evidence_dataset = get_open_retrieval_wiki_dataset()

    def faiss_wrapper(self):
        # Initialize FAISS wrapper on local rank = 0 as the evidence embeddings
        # is distributed over all the GPUs in a node and FAISS is not
        # thread-safe
        args = get_args()
        if args.local_rank == 0:
            # Get evidence embeddings computed using context encoder
            self.get_evidence_embedding()

            assert self.evidence_embedder_obj is not None
            self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size,
                                             embed_data=self.evidence_embedder_obj,
                                             use_gpu=self.faiss_use_gpu)

        # Wait for the FAISS index to be initialized in all the nodes
        torch.distributed.barrier()

    def generate_query_vectors(self, qa_data, split):
        """Encode every question of `qa_data` with the query encoder.

        Returns (query_tensor, reference_list): stacked query embeddings and
        the per-question gold answers.
        """
        self.eval_dataset = get_nq_dataset(qa_data, split)
        dataloader = get_one_epoch_nq_dataloader(self.eval_dataset)

        query_vectors = []
        reference_list = []

        for batch in dataloader:
            # batch also has query_tokens and query_pad_data
            query_tokens, query_mask, query_types, \
                query_len, reference = process_nq_batch(batch)

            assert len(self.model) == 1
            unwrapped_model = self.model[0]
            # Peel off nested .module wrappers until the object exposing
            # embed_text (the biencoder itself) is reached.
            while not hasattr(unwrapped_model, 'embed_text'):
                unwrapped_model = unwrapped_model.module

            with torch.no_grad():
                query_logits = unwrapped_model.embed_text(
                    unwrapped_model.query_model, query_tokens,
                    query_mask, query_types)

            reference_list.extend(reference)
            query_vectors.extend(query_logits.split(1, dim=0))
            if len(query_vectors) % 100 == 0:
                print_rank_0('Encoded queries {}'.format(len(query_vectors)))

        query_tensor = torch.cat(query_vectors, dim=0)
        print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size()))

        assert query_tensor.size(0) == len(self.eval_dataset)
        return query_tensor, reference_list

    def evaluate(self, qa_data, split):
        """Encode queries, retrieve top-k passages per node via FAISS,
        and print top-k hit counts/accuracies for `split`."""
        args = get_args()
        query_tensor, reference_list = self.generate_query_vectors(qa_data, \
            split)
        local_rank = args.local_rank
        rank = torch.distributed.get_rank()
        device_count = torch.cuda.device_count()
        # NOTE(review): assumes world_size is a multiple of the per-node
        # device count — confirm launcher configuration.
        num_nodes = torch.distributed.get_world_size() // device_count
        node_id = rank // device_count

        # Create one process group per node; remember the group (and its
        # first global rank) that this rank belongs to.
        for node in range(num_nodes):
            start_rank = node * device_count
            end_rank = (node + 1) * device_count
            ranks_list = list(range(start_rank, end_rank))
            node_group = torch.distributed.new_group(ranks=ranks_list)

            if node_id == node:
                device_start_rank = start_rank
                group = node_group

        # Gather all query embeddings within the node.
        # NOTE(review): `input_` is prepared but all_gather is handed
        # `query_tensor` directly; `input_` only determines the shape of
        # the gather buffers.
        input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(device_count)]
        torch.distributed.all_gather(tensor_list, query_tensor, group=group)

        if local_rank == 0 and self.mips_index is not None:
            # Only node-local rank 0 owns the FAISS index (see
            # faiss_wrapper); it searches for the whole node's queries.
            all_query_tensor = torch.cat(tensor_list, dim=0).contiguous()

            distance, topkindex = self.mips_index.search_mips_index(
                all_query_tensor, top_k=args.faiss_topk_retrievals,
                reconstruct=False)
            distance = torch.from_numpy(distance).cuda()
            topkindex = torch.LongTensor(topkindex).cuda()

        if local_rank != 0:
            # Other local ranks allocate receive buffers for the broadcast.
            distance = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.float32).cuda()
            topkindex = torch.empty(device_count * len(query_tensor), \
                args.faiss_topk_retrievals, dtype=torch.int64).cuda()

        torch.distributed.broadcast(distance, src=device_start_rank, \
            group=group)
        torch.distributed.broadcast(topkindex, src=device_start_rank, \
            group=group)

        # Each rank keeps only the slice corresponding to its own queries.
        distance = torch.split(distance, len(query_tensor), dim=0)\
            [local_rank]
        topkindex = torch.split(topkindex, len(query_tensor), dim=0)\
            [local_rank]

        top_ids_and_scores = []
        for darray, topkarray in zip(distance, topkindex):
            top_ids_and_scores.append((topkarray.tolist(), darray.tolist()))

        passages = self.evidence_dataset.id2text
        match_stats = calculate_matches(passages,
                                        reference_list,
                                        top_ids_and_scores,
                                        workers_num=args.num_workers,
                                        match_type=args.faiss_match)
        top_k_hits = match_stats.top_k_hits

        print_rank_0("{} SET RESULTS".format(split))
        print_rank_0("topk-{} documents hits {}".format(
            args.faiss_topk_retrievals, top_k_hits))

        # Convert cumulative hit counts to accuracies.
        top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits]
        print_rank_0("top-k documents hits accuracy {}".format(top_k_hits))

        for i in args.retriever_report_topk_accuracies:
            print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100))

        return
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod
import numpy as np
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask
def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
    """Tokenize each context dict ({'title', 'text'}) into padded id and
    token-type lists.

    Returns (ctx_id_list, ctx_types_list), one entry per context.
    """
    ctx_id_list, ctx_types_list = [], []
    for context in ctx_list:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        # BUGFIX: was `tokenizer.sep_id`; every other use in this file
        # (e.g. build_tokens_types_paddings_from_text) reads the [SEP] id
        # via `tokenizer.sep`.
        ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

        ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(
            ctx_ids, max_seq_length, tokenizer.cls, tokenizer.sep,
            tokenizer.pad)
        ctx_id_list.append(ctx_ids)
        ctx_types_list.append(ctx_types)

    return ctx_id_list, ctx_types_list
def build_tokens_types_paddings_from_text(query, context,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed.

    Returns (query_ids, query_types, query_pad_mask,
             ctx_ids, ctx_types, ctx_pad_mask); the three ctx_* values are
    None when `context` is None.
    """
    query_ids = tokenizer.tokenize(query)
    query_ids, query_types, query_pad_mask = \
        build_tokens_types_paddings_from_ids(query_ids, max_seq_length,
            tokenizer.cls, tokenizer.sep, tokenizer.pad)

    ctx_ids = ctx_types = ctx_pad_mask = None
    if context is not None:
        # Appending the title of the context at front
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

        # BUGFIX: this call previously ran unconditionally, so a None
        # context crashed on len(None) despite the None check above.
        ctx_ids, ctx_types, ctx_pad_mask = \
            build_tokens_types_paddings_from_ids(extended_ctx_ids,
                max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return query_ids, query_types, query_pad_mask, \
        ctx_ids, ctx_types, ctx_pad_mask
# Similar to the code in tasks/data_utils.py, with some changes
# Similar code tasks/data_utils with some changes
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Assemble [CLS] + text + [SEP], trim to max_seq_length, and pad.

    Returns (enc_ids, tokentypes_enc, pad_mask) where pad_mask is a numpy
    int64 array holding 1 for real tokens and 0 for padding.
    """
    # [CLS] followed by the text tokens, all in segment 0.
    enc_ids = [cls_id] + list(text_ids)

    # Keep one slot free for the trailing [SEP].
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[:max_seq_length - 1]
    enc_ids.append(sep_id)

    tokentypes_enc = [0] * len(enc_ids)
    num_tokens_enc = len(enc_ids)

    # Pad out to max_seq_length (token types are padded with pad_id,
    # matching the original behavior).
    padding_length = max_seq_length - num_tokens_enc
    if padding_length > 0:
        enc_ids += [pad_id] * padding_length
        tokentypes_enc += [pad_id] * padding_length

    pad_mask = np.array([1] * num_tokens_enc + [0] * padding_length,
                        dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask
def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
                 neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
    """Convert to numpy and return a sample consumed by the batch producer."""
    query_ids = np.array(query_ids, dtype=np.int64)
    query_types = np.array(query_types, dtype=np.int64)
    ctx_ids = np.array(ctx_ids, dtype=np.int64)
    ctx_types = np.array(ctx_types, dtype=np.int64)

    sample = {
        'query': query_ids,
        'query_mask': make_attention_mask(query_ids, query_ids),
        'query_types': query_types,
        'query_pad_mask': query_pad_mask,
        'context': ctx_ids,
        'context_mask': make_attention_mask(ctx_ids, ctx_ids),
        'context_types': ctx_types,
        'context_pad_mask': ctx_pad_mask,
        'reference': answers
    }

    if include_neg:
        # Stack the negatives and build one attention mask per negative row.
        neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
        sample['neg_context'] = neg_ctx_ids
        sample['neg_context_types'] = np.array(neg_ctx_types_list,
                                               dtype=np.int64)
        sample['neg_context_mask'] = np.array(
            [make_attention_mask(ids, ids) for ids in neg_ctx_ids],
            dtype=np.int64)

    return sample
class OpenRetrievalAbstractDataset(ABC, Dataset):
    """Open Retrieval base dataset class.

    Loads DPR-style samples from one or more files, optionally subsamples
    them, and tokenizes query/context pairs (plus negatives) on access.
    """

    def __init__(self, task_name, dataset_name, datapaths, tokenizer,
                 max_seq_length, evaluate=False):
        # Store inputs.
        args = get_args()
        self.evaluate = evaluate
        self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
        self.val_av_rank_other_neg = args.val_av_rank_other_neg
        self.train_with_neg = args.train_with_neg
        self.train_hard_neg = args.train_hard_neg

        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(self.process_samples_from_single_path(datapath))

        # CLEANUP: removed a redundant second call to get_args() here.
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize one sample; attach negatives when evaluating or when
        training with negatives."""
        raw_sample = self.samples[idx]

        query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
            ctx_pad_mask = build_tokens_types_paddings_from_text(
                raw_sample['question'], raw_sample['pos_context'],
                self.tokenizer, self.max_seq_length)

        if self.evaluate:
            neg_ctx_list = \
                raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
                raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        elif self.train_with_neg:
            hard_negative_ctx = raw_sample['hard_negative_context']
            negative_ctx = raw_sample['negative_context']
            # CLEANUP: removed a vestigial `if True:` wrapper around the
            # two shuffles (behavior unchanged).
            random.shuffle(hard_negative_ctx)
            random.shuffle(negative_ctx)

            neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the DPR version of the Google NQ training data, more than
            # 50 examples are missing hard negatives; in those cases,
            # substitute hard negatives with simple negatives.
            if len(neg_ctx_list) < self.train_hard_neg:
                neg_ctx_list += negative_ctx[:self.train_hard_neg -
                                             len(neg_ctx_list)]

            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        else:
            neg_ctx_id_list = None
            neg_ctx_types_list = None

        sample = build_sample(query_ids, query_types, query_pad_mask,
                              ctx_ids, ctx_types, ctx_pad_mask,
                              raw_sample['answers'],
                              neg_ctx_id_list, neg_ctx_types_list,
                              include_neg=self.evaluate or self.train_with_neg)

        return sample

    @staticmethod
    @abstractmethod
    def process_samples_from_single_path(filename):
        """Abstract method that takes a filename and returns a list of
        dataset samples; each sample is a dict with the fields read by
        __getitem__ ('question', 'pos_context', 'negative_context',
        'hard_negative_context', 'answers').
        """
        pass
def normalize_question(question):
    """Strip a single trailing '?' from `question`, if present.

    BUGFIX: uses endswith() instead of indexing question[-1], so an empty
    question no longer raises IndexError.
    """
    if question.endswith('?'):
        question = question[:-1]
    return question
# The following class reads the datasets for training retriever as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)
class NQSupervisedDataset(OpenRetrievalAbstractDataset):
    """Natural Questions retriever dataset in the DPR json format."""

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 evaluate=False):
        super().__init__('natural_questions_ret',
                         name,
                         datapaths,
                         tokenizer,
                         max_seq_length,
                         evaluate=evaluate)

    @staticmethod
    def process_samples_from_single_path(filename):
        """Implement abstract method: load DPR-style json rows."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r', encoding="utf-8") as f:
            data = json.load(f)

            for row in data:
                question = normalize_question(row['question'])
                pos_context = row['positive_ctxs'][0]

                # Hard-negative and simple-negative context lists may be
                # absent (empty) for some rows.
                hard_neg_context = row['hard_negative_ctxs'] \
                    if row['hard_negative_ctxs'] else []
                neg_context = row['negative_ctxs'] \
                    if row['negative_ctxs'] else []

                samples.append({'question': question,
                                'pos_context': pos_context,
                                'hard_negative_context': hard_neg_context,
                                'negative_context': neg_context,
                                'answers': row['answers']})
                total += 1
                if total % 5000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Evaluation utilities."""
from collections import OrderedDict
import math
import numpy as np
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from megatron import get_args, print_rank_0
from megatron.core import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader
def task_collate_fn(batch_data):
    """Collate a list of sample dicts into a batch of LongTensors.

    All query/context fields are stacked into int64 tensors; negative
    contexts (when present) are concatenated along dim 0 because each
    sample carries several of them. 'reference' stays a plain list.

    CLEANUP: removed the unused `batch_size` local.
    """
    tensorized = OrderedDict()
    for d in batch_data:
        for k, v in d.items():
            tensorized.setdefault(k, []).append(v)

    for key in ('query', 'query_mask', 'query_types', 'query_pad_mask',
                'context', 'context_mask', 'context_types',
                'context_pad_mask'):
        tensorized[key] = torch.LongTensor(tensorized[key])

    if 'neg_context' in tensorized:
        for key in ('neg_context', 'neg_context_mask', 'neg_context_types'):
            tensorized[key] = torch.LongTensor(
                np.concatenate(tensorized[key]))

    return tensorized
def process_batch(batch):
    """Move a collated batch onto the GPU and unpack model inputs."""
    def _tokens(key):
        # Integer inputs are cast to int64 before the device transfer.
        return batch[key].long().cuda()

    def _inv_mask(key):
        # Boolean mask: True marks positions with mask value < 0.5.
        return (batch[key] < 0.5).cuda()

    query_tokens = _tokens('query')
    query_mask = _inv_mask('query_mask')
    query_types = _tokens('query_types')
    query_pad_mask = _tokens('query_pad_mask')
    context_tokens = _tokens('context')
    context_mask = _inv_mask('context_mask')
    context_types = _tokens('context_types')
    context_pad_mask = _tokens('context_pad_mask')

    if 'neg_context' in batch:
        neg_context_tokens = _tokens('neg_context')
        neg_context_mask = _inv_mask('neg_context_mask')
        neg_context_types = _tokens('neg_context_types')
    else:
        neg_context_tokens = None
        neg_context_mask = None
        neg_context_types = None

    return query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
        batch['reference']
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide function that calculates accuracies.

    Builds the validation dataloader up front and returns a metrics_func
    closure that runs retrieval_loss over it and logs per-epoch stats.
    """
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build dataloaders
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            # Prediction output is only supported with the rank-0 sampler.
            # CLEANUP: removed the dead assignment `names = 'predictions'`
            # (the variable was never used).
            assert rank0sampler
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            output = retrieval_loss(model, dataloader)
            stats_dict, total = output
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            # BUGFIX: corrected "calcuate" -> "calculate" in the log line.
            print_rank_0("taken time to calculate metrics {:.3f}".format(
                time.time() - start_time))
        else:
            raise AssertionError("{} Task not supported".format(args.task))

    return metrics_func
def retrieval_loss(model, dataloader):
    """Compute in-batch retrieval ranking metrics over a validation set.

    Accumulates, across all batches, the summed rank of each query's
    positive context plus top-k hit counts (averaged over the
    data-parallel group) and returns (stats_dict, total) where `total` is
    the number of queries seen.
    """
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
        args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward.
            query_tokens, query_mask, query_types, _, \
                context_tokens, context_mask, context_types, _, \
                neg_context_tokens, neg_context_mask, neg_context_types, \
                reference = process_batch(batch)

            # NOTE(review): assumes negatives are present in every batch —
            # torch.cat would fail if neg_context_tokens were None.
            query_logits, context_logits = unwrapped_model(query_tokens,
                query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            # Similarity of every query against every (pos + neg) context.
            retrieval_scores = torch.matmul(query_logits,
                torch.transpose(context_logits, 0, 1))

            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            # The i-th query's positive context sits at column i.
            local_batch_size = query_logits.shape[0]
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            # Full sort of the scores (k == number of columns).
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                k=softmax_scores.shape[1], sorted=True)

            def topk_accuracy(k):
                # Count queries whose positive context ranks in the top k.
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
                        range(local_batch_size)])])

            def get_rank():
                # Sum of the (0-based) ranks of the positive contexts.
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
                        for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k) for k in \
                args.retriever_report_topk_accuracies]

            rank = get_rank()
            losses = average_losses_across_data_parallel_group([rank, \
                *topk_accs])

            # create stats_dict with retrieval loss and all specified
            # top-k accuracies
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                zip(args.retriever_report_topk_accuracies, losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""ORQA finetuning/evaluation."""
from functools import partial
import sys
import math
import torch
import torch.nn.functional as F
from megatron import get_args, get_timers, get_tokenizer, print_rank_0
from megatron.core import mpu
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator
# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
    """Right-pad `input_` (a 2D tensor) along dim 0 to the maximum length
    across `group`.

    all_gather requires identical shapes on every rank, so each rank first
    gathers every rank's dim-0 length; if this rank's tensor is shorter
    than the max, it is zero-padded at the end of dim 0.
    """
    # gather the size of the first dimension of the tensor from all ranks
    current_length = input_.size()[0]
    first_dim = torch.tensor([[current_length]],
        device=torch.cuda.current_device())
    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
    input_list[rank].copy_(first_dim)
    torch.distributed.all_gather(input_list, first_dim, group=group)
    all_input_list = torch.cat(input_list, dim=0).contiguous()
    max_length = torch.max(all_input_list)

    # if the size are different than the max, extend the tensor
    # accordingly
    if max_length > current_length:
        # F.pad's pad tuple is ordered from the last dimension backwards,
        # so the final entry pads the end of dim 0.
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
            tuple([max_length - current_length])
        input_ = F.pad(input=input_, pad=padding)

    return input_
def orqa(Dataset):
    """Finetune/evaluate a biencoder retriever using the given Dataset class.

    Defines the forward step, loss, and dataset/model providers as closures
    over `Dataset`, then hands them to the generic finetune() driver.
    """

    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()
        tokenizer = get_tokenizer()

        # Get the batch.
        timers('batch generator', log_level=2).start()
        try:
            batch_ = next(batch)
        except BaseException:
            # `batch` may be a plain batch rather than an iterator.
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
            context_tokens, context_mask, context_types, context_pad_mask, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch_)

        timers('batch generator').stop()
        local_batch_size = query_tokens.shape[0]

        # Text representation of query and context
        # NOTE(review): query_list/context_list are built but not used
        # below — presumably kept for debugging.
        query_list, context_list = [], []
        for i in range(local_batch_size):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens is not None:
            # Pad negatives so all ranks agree on shape before any gather.
            neg_context_tokens = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_types)

        if neg_context_tokens is not None:
            # Negatives are appended after the positive contexts.
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)

    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
        """In-batch softmax cross-entropy over (query, context) similarity,
        with logits all-gathered across the data-parallel group."""
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # recall we assert that model_parallel_size == 1
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            # Gather context logits from all ranks (detached copies).
            input_ = torch.empty_like(context_logits).copy_(\
                context_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order
            assert tensor_list[rank].sum().item() == \
                context_logits.sum().item()

            # Preserves the gradient
            tensor_list[rank] = context_logits
            all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

            # Query tensors
            input_ = torch.empty_like(query_logits).copy_(\
                query_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order
            assert tensor_list[rank].sum().item() == query_logits.sum().item()

            # Preserves the gradient
            tensor_list[rank] = query_logits
            all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(all_query_logits,
            torch.transpose(all_context_logits, 0, 1))
        # Scaling the retrieval scores
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # if the world size is 3, local batch size is 4, and
            # local context size is 8, what we want is
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
            labels = []
            local_context_size = context_tokens.shape[0]
            for i in range(world_size):
                j = i * local_context_size
                labels.extend(list(range(j, j + local_batch_size)))
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            # Without negatives, the i-th query's positive is column i.
            labels = torch.arange(global_batch_size).long().cuda()

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)

        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce loss for logging.
        reduced_loss = average_losses_across_data_parallel_group([loss, \
            correct_predictions_count])

        # Loss scaling for correct losses in Supervised Retrieval
        loss = loss * mpu.get_data_parallel_world_size()

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training',
                                args.train_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=False)
        valid_dataset = Dataset('validation',
                                args.valid_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=True)
        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        print_rank_0('building retriever model for {} ...'.format(args.task))

        model = biencoder_model_provider(only_context_model=False,
            only_query_model=False,
            biencoder_shared_query_context_model=\
            args.biencoder_shared_query_context_model,
            pre_process=pre_process, post_process=post_process)
        return model

    def single_dataset_provider(datapath):
        """Build one evaluation dataset named after its first file."""
        args = get_args()
        tokenizer = get_tokenizer()

        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name,
                       datapath,
                       tokenizer,
                       args.retriever_seq_length,
                       evaluate=True)

    def metrics_func_provider():
        """Provide metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)

    """Finetune/evaluate."""
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)
def main():
    """Dispatch the --task selection to its ORQA Dataset class."""
    args = get_args()

    if args.task != 'RET-FINETUNE-NQ':
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    # Imported lazily so unsupported tasks fail fast above.
    from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
    orqa(Dataset)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler
from megatron import print_rank_0, get_args, get_tokenizer
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_nq_dataset(qa_data, split):
    """Construct the Google NQ evaluation dataset for the given split."""
    tokenizer = get_tokenizer()
    seq_length = get_args().retriever_seq_length

    return NQDataset('Google NQ {} Split'.format(split),
                     'Google Natural Questions',
                     qa_data,
                     tokenizer,
                     seq_length)
def process_nq_batch(batch):
    """Move one collated NQ batch onto the GPU and unpack model inputs."""
    tokens = batch['token_ids'].long().cuda()
    # Boolean mask: True marks positions with mask value < 0.5.
    mask = (batch['token_mask'] < 0.5).cuda()
    types = batch['token_types'].long().cuda()
    seq_len = batch['seq_len'].long().cuda()

    return tokens, mask, types, seq_len, batch['reference']
class CustomDataLoader(DataLoader):
    """DataLoader that collates list-of-dict NQ samples into LongTensor
    batches, leaving 'reference' as a plain list."""

    def __init__(self, dataset, eval=False, **kwargs):
        # Fall back to our own collate function unless the caller gives one.
        if kwargs.get('collate_fn', None) is None:
            kwargs['collate_fn'] = self._collate_fn
        # NOTE: `eval` shadows the builtin, but renaming it would break
        # callers passing it as a keyword argument.
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        """Stack per-sample fields; tensorize all except 'reference'.

        CLEANUP: removed the unused `batch_size` local.
        """
        tensorized = OrderedDict()
        for d in batch_data:
            for k, v in d.items():
                tensorized.setdefault(k, []).append(v)
        # Expect exactly the five keys emitted by build_sample().
        assert len(tensorized) == 5

        tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
        tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
        tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
        tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
        return tensorized
def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size.

    NOTE: This dataloader is not distributed !!!
    """
    args = get_args()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size

    # Sequential order with drop_last=False so every example is covered.
    batch_sampler = BatchSampler(
        torch.utils.data.SequentialSampler(dataset),
        batch_size=micro_batch_size,
        drop_last=False)

    # Data loader. Note that batch size is the per GPU batch size.
    return CustomDataLoader(dataset,
                            batch_sampler=batch_sampler,
                            num_workers=args.num_workers,
                            pin_memory=True)
def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length):
    """Tokenize `src_text` and delegate to the id-based builder."""
    return build_tokens_types_paddings_from_ids(
        tokenizer.tokenize(src_text),
        max_seq_length,
        tokenizer.cls,
        tokenizer.sep,
        tokenizer.pad)
def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id,
                                         sep_id, pad_id):
    """Assemble [CLS] + src + [SEP], trim to max_seq_length, and pad.

    Returns (enc_ids, tokentypes_enc, num_tokens_enc) where num_tokens_enc
    counts the real (non-padding) tokens.

    TODO: Design modular interface to reuse this function. This is getting
    repeated multiple times in different tasks.
    """
    # [CLS] followed by the source tokens, all in segment 0.
    enc_ids = [cls_id] + list(src_ids)

    # Reserve the final slot for [SEP].
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[:max_seq_length - 1]
    enc_ids.append(sep_id)

    tokentypes_enc = [0] * len(enc_ids)
    num_tokens_enc = len(enc_ids)

    # Pad out to max_seq_length (token types are padded with pad_id,
    # matching the original behavior).
    padding_length = max_seq_length - num_tokens_enc
    if padding_length > 0:
        enc_ids += [pad_id] * padding_length
        tokentypes_enc += [pad_id] * padding_length

    return enc_ids, tokentypes_enc, num_tokens_enc
def build_sample(token_ids, token_types, num_tokens, reference):
    """Convert fields to numpy and assemble one sample dict for the
    batch producer."""
    ids_arr = np.array(token_ids, dtype=np.int64)
    types_arr = np.array(token_types, dtype=np.int64)

    return {
        'token_ids': ids_arr,
        'token_mask': make_attention_mask(ids_arr, ids_arr),
        'token_types': types_arr,
        'seq_len': num_tokens,
        'reference': reference
    }
class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using Google NQ dataset.
    """
    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0(' >> total number of samples: {}'.format(\
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize and pad one question; answers pass through as the
        sample's reference."""
        raw_sample = self.samples[idx]

        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(raw_sample['question'],
                self.tokenizer, self.max_seq_length)

        sample = build_sample(ques_tokens,
                              tokentypes_enc,
                              num_tokens_ques,
                              raw_sample['answers'])
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        """Read a TSV file of (question, answers) rows into sample dicts."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for row in reader:
                question = row[0]
                # SECURITY NOTE(review): eval() on file contents executes
                # arbitrary code; if this column is always a Python list
                # literal, ast.literal_eval would be safer — confirm data.
                answers = eval(row[1])

                sample = {'question': question, 'answers': answers}
                total += 1
                samples.append(sample)

                if total % 1000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Set of utilities for Q&A results validation tasks - Retriever passage
validation and Reader predicted answer validation
"""
import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict
import regex as re
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
logger = logging.getLogger(__name__)
# top_k_hits[k]: number of questions with a correct document among their
# top k+1 retrieved results; questions_doc_hits: per-question, per-document
# boolean hit lists.
QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
    'questions_doc_hits'])
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
                      answers: List[List[str]],
                      closest_docs: List[Tuple[List[object], List[float]]],
                      workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges results
    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answers' list. One list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: amount of parallel threads to process data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
        top_k_hits - a list where the index is the amount of top documents
        retrieved and the value is the total amount of valid matches across
        an entire dataset.
        questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    # Publish the document store as a module-level global so forked worker
    # processes inherit it instead of pickling it per task.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)
    questions_answers_docs = zip(answers, closest_docs)

    logger.info('Matching answers in top docs...')
    # Use the pool as a context manager so worker processes are always
    # reclaimed (the original code leaked the pool).
    with ProcessPool(processes=workers_num) as processes:
        scores = processes.map(get_score_partial, questions_answers_docs)
    logger.info('Per question validation results len=%d', len(scores))

    # top_k_hits[k] accumulates, for each question, whether any of its top
    # k+1 retrieved documents contained an answer.
    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    Returns one boolean per retrieved document.
    """
    global dpr_all_documents
    answers, (doc_ids, doc_scores) = questions_answers_docs

    hits = []
    for doc_id in doc_ids:
        text = dpr_all_documents[doc_id][0]
        if text is None:
            # Cannot find the document text for some reason.
            logger.warning("no doc in db")
            hits.append(False)
        else:
            hits.append(has_answer(answers, text, tokenizer, match_type))
    return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Slide a window over the document tokens for each answer.
        doc_words = tokenizer.tokenize(text).words(uncased=True)
        for candidate in answers:
            answer_words = tokenizer.tokenize(
                _normalize(candidate)).words(uncased=True)
            window = len(answer_words)
            for start in range(len(doc_words) - window + 1):
                if doc_words[start: start + window] == answer_words:
                    return True
    elif match_type == 'regex':
        # Each answer is interpreted as a regex pattern.
        for candidate in answers:
            if regex_match(text, _normalize(candidate)):
                return True
    return False
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text.

    Returns False (rather than raising) when *pattern* is not a valid
    regular expression.
    """
    try:
        compiled = re.compile(
            pattern,
            flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
        )
    except re.error:
        # Narrowed from BaseException: only an invalid pattern should be
        # treated as "no match"; other errors propagate.
        return False
    return compiled.search(text) is not None
# function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
    """True when prediction and ground truth match exactly after answer
    normalization (lowercase, no punctuation/articles, squeezed spaces)."""
    norm_prediction = _normalize_answer(prediction)
    norm_ground_truth = _normalize_answer(ground_truth)
    return norm_prediction == norm_ground_truth
def _normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _normalize(text):
    """Apply NFD unicode normalization (decomposes accented characters so
    token/string comparisons are accent-insensitive downstream)."""
    return unicodedata.normalize('NFD', text)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Most of the tokenizers code here is copied from the DrQA codebase to avoid adding an extra dependency
"""
import copy
import logging
import regex
import spacy
logger = logging.getLogger(__name__)
class Tokens(object):
    """A class to represent a list of tokenized text.

    Each element of `data` is a tuple whose fields are addressed by the
    index constants below; POS/LEMMA/NER fields are only meaningful when
    the matching annotator was enabled.
    """
    # Token-tuple field indices.
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        # Shallow copy: the new Tokens shares annotators/opts, new data slice.
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.

        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.

        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.

        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return the ngram as a string vs list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        # Each pair (s, e) is end-exclusive: e already carries the +1, so
        # words[s:e] below selects the full ngram.
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the full run of tokens sharing this tag.
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups
class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """
    def tokenize(self, text):
        # Subclasses must return a Tokens instance for *text*.
        raise NotImplementedError

    def shutdown(self):
        # Hook for releasing resources; default is a no-op.
        pass

    def __del__(self):
        self.shutdown()
class SimpleTokenizer(Tokenizer):
    """Regex-based tokenizer: alphanumeric runs or single non-whitespace
    characters; supports no annotators."""
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        # Match either a run of letters/digits/marks, or any single
        # non-whitespace, non-control character.
        self._regexp = regex.compile(
            '({})|({})'.format(self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        requested = kwargs.get('annotators', {})
        if len(requested) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        matches = list(self._regexp.finditer(text))
        n_matches = len(matches)
        data = []
        for idx, match in enumerate(matches):
            token = match.group()
            span = match.span()
            # A token's attached whitespace runs up to the next match's
            # start (or the token's own end for the final match).
            ws_start = span[0]
            if idx + 1 < n_matches:
                ws_end = matches[idx + 1].span()[0]
            else:
                ws_end = span[1]
            data.append((token, text[ws_start: ws_end], span))
        return Tokens(data, self.annotators)
class SpacyTokenizer(Tokenizer):
    """Tokenizer backed by spaCy, optionally producing pos/lemma/ner fields.

    NOTE(review): uses the legacy spaCy 1.x-style API (`nlp.tagger(tokens)`,
    `nlp.entity(tokens)`, load kwargs like `parser=False`) — confirm the
    pinned spaCy version still supports it.
    """
    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        # Disable pipeline components no requested annotator needs.
        nlp_kwargs = {'parser': False}
        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)
        # Run only the annotators that were requested.
        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            self.nlp.tagger(tokens)
        if 'ner' in self.annotators:
            self.nlp.entity(tokens)

        data = []
        for i in range(len(tokens)):
            # Get whitespace: from this token's start to the next token's
            # start (or this token's end, for the last token).
            start_ws = tokens[i].idx
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1].idx
            else:
                end_ws = tokens[i].idx + len(tokens[i].text)

            data.append((
                tokens[i].text,
                text[start_ws: end_ws],
                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
                tokens[i].tag_,
                tokens[i].lemma_,
                tokens[i].ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
import glob
import json
import os
import time
from torch.utils.data import Dataset
from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text
# Number of answer choices per RACE question (labels A-D).
NUM_CHOICES = 4
# Maximum token length for the combined question+answer segment.
MAX_QA_LENGTH = 128


class RaceDataset(Dataset):
    """RACE multiple-choice dataset; each item bundles NUM_CHOICES samples
    that the trainer collapses into the batch dimension."""

    def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
                 max_qa_length=MAX_QA_LENGTH):

        self.dataset_name = dataset_name
        print_rank_0(' > building RACE dataset for {}:'.format(
            self.dataset_name))

        paths_msg = '  > paths:' + ''.join(' ' + p for p in datapaths)
        print_rank_0(paths_msg)

        samples = []
        for path in datapaths:
            samples.extend(process_single_datapath(path, tokenizer,
                                                   max_qa_length,
                                                   max_seq_length))
        self.samples = samples
        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

        # This indicates that each "sample" has multiple samples that
        # will collapse into batch dimension
        self.sample_multiplier = NUM_CHOICES

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
    """Read in RACE files, combine, clean-up, tokenize, and convert to
    samples.

    Each question produces one sample holding NUM_CHOICES (ids, types,
    paddings) triples plus the gold-choice label.
    """
    print_rank_0('   > working on {}'.format(datapath))
    start_time = time.time()

    # Get list of files.
    filenames = glob.glob(os.path.join(datapath, '*.txt'))

    samples = []
    num_docs = 0
    num_questions = 0
    num_samples = 0
    # Load all the files: one JSON document per line.
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                data = json.loads(line)
                num_docs += 1

                context = data["article"]
                questions = data["questions"]
                choices = data["options"]
                answers = data["answers"]
                # Check the length.
                assert len(questions) == len(answers)
                assert len(questions) == len(choices)

                # Context: clean up and convert to ids.
                context = clean_text(context)
                context_ids = tokenizer.tokenize(context)

                # Loop over questions.
                for qi, question in enumerate(questions):
                    num_questions += 1
                    # Label: answers are single letters starting at 'A'.
                    label = ord(answers[qi]) - ord("A")
                    assert label >= 0
                    assert label < NUM_CHOICES
                    assert len(choices[qi]) == NUM_CHOICES

                    # For each question, build num-choices samples.
                    ids_list = []
                    types_list = []
                    paddings_list = []
                    for ci in range(NUM_CHOICES):
                        choice = choices[qi][ci]
                        # Merge with choice: fill-in-the-blank questions use
                        # '_' as the slot; otherwise append the choice.
                        if "_" in question:
                            qa = question.replace("_", choice)
                        else:
                            qa = " ".join([question, choice])
                        # Clean QA.
                        qa = clean_text(qa)
                        # Tokenize.
                        qa_ids = tokenizer.tokenize(qa)
                        # Trim if needed.
                        if len(qa_ids) > max_qa_length:
                            qa_ids = qa_ids[0:max_qa_length]

                        # Build the sample.
                        ids, types, paddings \
                            = build_tokens_types_paddings_from_ids(
                                qa_ids, context_ids, max_seq_length,
                                tokenizer.cls, tokenizer.sep, tokenizer.pad)

                        ids_list.append(ids)
                        types_list.append(types)
                        paddings_list.append(paddings)

                    # Convert to numpy and add to samples
                    samples.append(build_sample(ids_list, types_list,
                                                paddings_list, label,
                                                num_samples))
                    num_samples += 1

    elapsed_time = time.time() - start_time
    print_rank_0('    > processed {} document, {} questions, and {} samples'
                 ' in {:.2f} seconds'.format(num_docs, num_questions,
                                             num_samples, elapsed_time))
    return samples
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Race."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset
def train_valid_datasets_provider():
    """Provide the RACE train and validation datasets from the
    command-line data paths."""
    args = get_args()
    tokenizer = get_tokenizer()

    splits = (('training', args.train_data), ('validation', args.valid_data))
    train_dataset, valid_dataset = (
        RaceDataset(split_name, paths, tokenizer, args.seq_length)
        for split_name, paths in splits)
    return train_dataset, valid_dataset
def model_provider(pre_process=True, post_process=True):
    """Build the multiple-choice model for RACE finetuning."""
    print_rank_0('building multichoice model for RACE ...')
    # Two token types: question+answer segment vs. context segment.
    return MultipleChoice(num_tokentypes=2,
                          pre_process=pre_process,
                          post_process=post_process)
def metrics_func_provider():
    """Provide metrics callback function."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        # Derive a short dataset name from the path component after 'RACE'.
        name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)
def main():
    """Entry point: finetune on RACE with per-epoch accuracy reporting."""
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Vision-classification finetuning/evaluation."""
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group
def classification():
    """Finetune/evaluate a ViT classifier on an ImageNet-style dataset."""

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w),
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0("building classification model for ImageNet ...")

        return VitClassificationModel(num_classes=args.num_classes, finetune=True,
                                      pre_process=pre_process, post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        labels = batch[1].cuda().contiguous()
        return images, labels

    def cross_entropy_loss_func(labels, output_tensor):
        """Cross-entropy loss, plus a data-parallel-averaged copy for logging."""
        logits = output_tensor

        # Cross-entropy loss.
        loss = F.cross_entropy(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator", log_level=2).start()
        try:
            # `batch` may be an iterator or an already-materialized batch.
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, labels)

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )
def main():
    """Entry point: run vision-classification finetuning/evaluation."""
    classification()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Evaluation utilities."""
import os
from functools import partial
import torch
from megatron import get_args
from megatron import print_rank_0, print_rank_last
from megatron.core import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms
def accuracy_func_provider():
    """Provide function that calculates accuracies."""
    args = get_args()
    data_path = args.data_path
    crop_size = (args.img_h, args.img_w)

    # Build dataloaders.
    # NOTE(review): assumes data_path[1] is the validation split directory —
    # confirm against the argument convention used by callers.
    val_data_path = data_path[1]
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)

    # Drop the last incomplete batch only when data-parallel, so every rank
    # processes the same number of batches.
    dataloader = build_data_loader(
        dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        drop_last=(mpu.get_data_parallel_world_size() > 1),
        shuffle=False
    )

    def metrics_func(model, epoch):
        """Compute and print validation accuracy for the given epoch."""
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )

    return metrics_func
def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers"""
    forward_backward_func = get_forward_backward_func()
    # Switch to eval mode for the accuracy pass.
    for m in model:
        m.eval()

    def loss_func(labels, output_tensor):
        """Count correct predictions for this micro-batch; the scalar loss
        returned is a dummy zero (no backward pass in forward_only mode)."""
        logits = output_tensor

        loss_dict = {}
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels).float()
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    #defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            # `batch` may be an iterator or an already-materialized batch.
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(loss_func, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        for _, batch in enumerate(dataloader):
            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                total += loss_dict['total']
                correct += loss_dict['correct']

    # Restore training mode.
    for m in model:
        m.train()

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        # NOTE(review): non-last pipeline stages fall through and implicitly
        # return None — callers appear to use the result on the last stage
        # only; confirm.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        return correct_ans, total_count
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Finetune utilities."""
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import utils
from megatron.core import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.core.enums import ModelType
def process_batch(batch):
    """Move an (images, labels) batch onto the GPU as contiguous tensors."""
    images_cpu, labels_cpu = batch[0], batch[1]
    images = images_cpu.cuda().contiguous()
    labels = labels_cpu.cuda().contiguous()
    return images, labels
def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Shard the dataset across the data-parallel group; shuffling is
    # delegated to the sampler, so the loader itself never shuffles.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=mpu.get_data_parallel_world_size(),
        rank=mpu.get_data_parallel_rank(),
        drop_last=drop_last,
        shuffle=shuffle,
    )

    # Data loader. Note that batch size is the per GPU batch size.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )
def _build_infinite_size_dataloader(dataloader):
"""Build a looped dataloader with infinite size."""
iterator = dataloader.__iter__()
while True:
try:
yield iterator.__next__()
except StopIteration:
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')
    # Training dataset: shuffled, keep the last partial batch.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, False, True)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch
    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, True, False)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader
def _train(
    model,
    optimizer,
    opt_param_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
    process_non_loss_data_func=None
):
    """Train the model.

    Epoch-based finetuning loop with periodic logging, autoresume checks,
    checkpointing, evaluation, and an optional end-of-epoch callback.
    Resumes mid-epoch from args.iteration when restarting from a checkpoint.
    """
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration (derived from the checkpointed iteration).
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch
    timers("interval-time", log_level=0).start(barrier=True)
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, opt_param_scheduler
            )
            iteration += 1

            # Logging.
            params_norm = None

            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad
            )

            # Autoresume: poll the cluster scheduler periodically.
            if args.adlr_autoresume and \
               iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(iteration, model, optimizer,
                                                  opt_param_scheduler)

            # Checkpointing at the configured interval.
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation at the configured interval.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)
def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks.

    Builds dataloaders, the model/optimizer/scheduler, optionally loads a
    pretrained checkpoint, then either trains for args.epochs epochs or,
    when epochs <= 0, runs the evaluation callback once.
    """
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders (only needed when training).
    timers("train/valid/test dataset/dataloder", log_level=0).start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloder").stop()

    # Build callback function.
    timers("callback function", log_level=0).start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer", log_level=0).start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            # Parameters whose name contains ".head." get a scaled LR.
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint", log_level=0).start(barrier=True)
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            # Standard Megatron checkpoint: temporarily repoint args.load.
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
        elif args.pretrained_checkpoint_type == 'external':
            # Raw torch state dict loaded directly into the backbone.
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        elif args.pretrained_checkpoint_type == 'constrastive':
            # Contrastive-pretraining checkpoint: keep only the teacher
            # backbone weights, stripped of their prefix.
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(args.pretrained_checkpoint_type))

        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        optimizer.reload_model_params()
    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloder",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)

    print_rank_0("done :-)")
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Main tasks functionality."""
import os
import sys
sys.path.append(
os.path.abspath(
os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir),
os.path.pardir,
)
)
)
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
    """Register task-specific command-line arguments; returns the parser."""
    group = parser.add_argument_group(title="tasks")

    # NOTE(review): the default 'segment' is not among the declared choices;
    # argparse does not validate defaults, so omitting --task matches no
    # dispatch branch in __main__ — confirm this is intentional.
    group.add_argument("--task", type=str, default="segment",
                       choices=["classify", "segment_setr",
                                "segment_segformer"],
                       help="task name.")
    group.add_argument("--epochs", type=int, default=None,
                       help="Number of finetunning epochs. Zero results in "
                            "evaluation only.")
    group.add_argument("--pretrained-checkpoint-type", type=str,
                       default="default",
                       choices=["default", "external", "constrastive"],
                       help="Type of pretrained checkpoint")
    group.add_argument("--pretrained-checkpoint", type=str, default=None,
                       help="Pretrained checkpoint used for finetunning.")
    group.add_argument("--seg-stride", type=int, default=None,
                       help="sliding window stride during evaluation")
    return parser
if __name__ == "__main__":
    initialize_megatron(extra_args_provider=get_tasks_args)
    args = get_args()

    # Dispatch to the requested vision task. Note the default --task value
    # ('segment') matches none of these branches, so nothing runs then.
    if args.task == 'classify':
        from tasks.vision.classification.classification import main
        main()
    elif args.task == 'segment_setr':
        from tasks.vision.segmentation.finetune_setr import main
        main()
    elif args.task == 'segment_segformer':
        from tasks.vision.segmentation.finetune_segformer import main
        main()
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# modified it to change max label index from 255 to 19 (num_classes)
import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    Examples:
        Get semantic segmentation target
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')
            img, smnt = dataset[0]
        Get multiple targets
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])
            img, (inst, col, poly) = dataset[0]
        Validate on the "coarse" set
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')
            img, smnt = dataset[0]
    """
    # 19 evaluated classes; train id 19 doubles as the shared "ignore" label
    # (this file was modified from torchvision to use 19 instead of 255).
    num_classes = 19
    ignore_index = 19
    # One RGB row per train id (row 19 = black for ignored pixels), used to
    # render predictions/ground truth as color images.
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]], dtype=torch.float, device='cuda')
    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id',
        'category', 'category_id', 'has_instances', 'ignore_in_eval', 'color'])
    # Full label table: raw dataset id -> train id (19 marks classes that are
    # excluded from evaluation).
    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]
    # Map raw dataset label ids to contiguous train ids (19 = ignore).
    label2trainid   = { label.id : label.train_id for label in classes}
    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        resolution: int = 1024,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        # NOTE(review): the directories below are hard-coded to the fine
        # `*_trainvaltest` layout, so mode='coarse' would still read gtFine
        # targets from targets_dir — confirm coarse mode is never used here.
        self.images_dir = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []
        # Pair every image with its labelId map by filename convention:
        # <city>_<seq>_<frame>_leftImg8bit.png -> <prefix>_<mode>_labelIds.png
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(file_name.split('_leftImg8bit')[0], self.mode)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))
    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        target = np.array(target)
        # Remap raw label ids to train ids via the lookup table.  The copy is
        # required so already-remapped pixels are not matched again by a
        # later raw id.
        target_copy = target.copy()
        for k, v in Cityscapes.label2trainid.items():
            binary_target = (target == k)
            target_copy[binary_target] = v
        target = target_copy
        target = Image.fromarray(target.astype(np.uint8))
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target
    def __len__(self) -> int:
        # Number of (image, target) pairs discovered at construction time.
        return len(self.images)
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron.data.autoaugment import ImageNetPolicy
from megatron import get_args
from PIL import Image, ImageOps
class VitSegmentationJointTransform():
    """Joint (image, mask) augmentation: random size/crop then random flip.

    In evaluation mode (``train=False``) the pair is returned unchanged.
    """

    def __init__(self, train=True, resolution=None):
        self.train = train
        if train:
            # The transforms take image and mask together so the same random
            # parameters are applied to both.
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if not self.train:
            return img, mask
        img, mask = self.transform0(img, mask)
        img, mask = self.transform1(img, mask)
        return img, mask
class VitSegmentationImageTransform():
    """Image-only transform: normalize and cast to the training dtype.

    Training mode additionally applies photometric distortion augmentation
    before normalization.
    """

    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        # Mixed precision is required; choose the dtype that matches it.
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        ops = []
        if train:
            assert resolution is not None
            ops.append(ET.PhotoMetricDistortion())
        ops.extend([
            T.ToTensor(),
            T.Normalize(*self.mean_std),
            T.ConvertImageDtype(self.data_type),
        ])
        self.transform = T.Compose(ops)

    def __call__(self, input):
        return self.transform(input)
class VitSegmentationTargetTransform():
    """Converts a segmentation mask into a ``torch.long`` label tensor."""

    def __init__(self, train=True, resolution=None):
        # `train` and `resolution` are accepted for signature parity with the
        # image transform; neither changes the output.
        self.train = train

    def __call__(self, input):
        labels = np.array(input, dtype=np.int32)
        return torch.from_numpy(labels).long()
class RandomSeedSegmentationDataset(Dataset):
    """Dataset wrapper that reseeds all RNGs per sample.

    Every sample is drawn after seeding torch/random/numpy with
    ``base_seed + 100 * epoch + idx`` so the random joint transforms are
    reproducible while still varying across epochs.
    """

    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # Shift the seed window so augmentations differ between epochs.
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        sample_seed = idx + self.curr_seed
        img, mask = self.dataset[idx]
        # Seed every RNG the transforms might draw from.
        torch.manual_seed(sample_seed)
        random.seed(sample_seed)
        np.random.seed(sample_seed)
        img, mask = self.joint_transform(img, mask)
        return self.image_transform(img), self.target_transform(mask)
def build_cityscapes_train_valid_datasets(data_path, image_size):
    """Build seeded Cityscapes train/val datasets at `image_size`.

    Also publishes the dataset's class count, ignore index, color table and
    normalization statistics on the global args for downstream consumers.
    """
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    # ImageNet normalization statistics (mean, std per RGB channel).
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _make_split(split, train):
        # One-line purpose: wrap a raw Cityscapes split with per-sample
        # seeding and the train/eval transform stack.
        raw = Cityscapes(
            root=data_path[0],
            split=split,
            mode='fine',
            resolution=image_size
        )
        return RandomSeedSegmentationDataset(
            raw,
            joint_transform=VitSegmentationJointTransform(
                train=train, resolution=image_size),
            image_transform=VitSegmentationImageTransform(
                train=train, resolution=image_size),
            target_transform=VitSegmentationTargetTransform(
                train=train, resolution=image_size))

    train_data = _make_split('train', True)
    val_data = _make_split('val', False)
    return train_data, val_data
def build_train_valid_datasets(data_path, image_size):
    """Public entry point for dataset building; Cityscapes is currently the
    only supported segmentation dataset."""
    return build_cityscapes_train_valid_datasets(data_path, image_size)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Vision-classification finetuning/evaluation."""
import numpy as np
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0, print_rank_last
from megatron.core import mpu
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
from megatron.model.vision.utils import resize
def calculate_iou(hist_data):
    """Derive IoU and accuracy metrics from a confusion matrix.

    Args:
        hist_data: (num_classes, num_classes) matrix where entry [g, p]
            counts pixels with ground truth g predicted as p.

    Returns:
        (per-class IoU array, overall pixel accuracy, mean class accuracy)
    """
    true_positives = np.diag(hist_data)
    # Overall accuracy: correctly classified pixels over all pixels.
    acc = true_positives.sum() / hist_data.sum()
    # Mean class accuracy: per-class recall, averaged while ignoring NaNs
    # from classes absent in the ground truth.
    acc_cls = np.nanmean(true_positives / hist_data.sum(axis=1))
    # IoU denominator: ground-truth pixels + predicted pixels - intersection.
    union = hist_data.sum(axis=1) + hist_data.sum(axis=0) - true_positives
    iu = true_positives / union
    return iu, acc, acc_cls
def fast_hist(pred, gtruth, num_classes):
    """Accumulate a (num_classes, num_classes) confusion matrix.

    Entry [g, p] counts pixels whose ground truth is g and prediction is p.
    Pixels whose ground-truth label falls outside [0, num_classes) are
    ignored.
    """
    # Keep only pixels with a valid ground-truth label.
    valid = (gtruth >= 0) & (gtruth < num_classes)
    # Encode each (gt, pred) pair as gt * num_classes + pred, count the
    # occurrences in one pass, then fold the counts into a square matrix:
    # true positives for class c land on the diagonal at [c, c].
    combined = num_classes * gtruth[valid].astype(int) + pred[valid]
    counts = np.bincount(combined, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)
def segmentation():
    """Finetune/evaluate Segformer semantic segmentation via `finetune`."""
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds
    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model
    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks
    def calculate_weight(masks, num_classes):
        """Per-class loss weights from the label histogram: classes present
        in the batch get (2 - frequency), absent classes get 1.  Currently
        unused (see the commented-out call below)."""
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float()/bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist
    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        """Pixel-wise cross-entropy loss.

        With `non_loss_data` True, instead returns a visualization tensor
        (input | colored prediction | colored ground truth stacked along
        height) together with the loss, for TensorBoard dumping.
        """
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        logits = output_tensor.contiguous().float()
        # Upsample logits to the label resolution before the pixel-wise loss.
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)
        # Cross-entropy loss.
        # weight = calculate_weight(masks, num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)
        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            # Map class ids to RGB via the color table for visualization.
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss
    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()
        # Get the batch.
        timers("batch generator", log_level=2).start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()
        # Forward model.
        output_tensor = model(images)
        return output_tensor, partial(cross_entropy_loss_func, images, masks)
    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""
        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()
        def loss_func(labels, output_tensor):
            """Accumulate a confusion matrix instead of a training loss."""
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)
            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.cpu().numpy()
            # NOTE(review): args.ignore_index is passed where fast_hist
            # expects num_classes; this works only because both are 19 for
            # Cityscapes — confirm if either value ever changes.
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict
        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)
            # Forward model.
            output_tensor = model(images)
            return output_tensor, partial(loss_func, labels)
        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                # Sum confusion matrices across micro-batches.
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']
        for m in model:
            m.train()
        # Reduce.
        if mpu.is_pipeline_last_stage():
            # Sum the confusion matrix over the data-parallel group before
            # deriving the metrics.
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)
            return iu, miou
    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )
        def metrics_func(model, epoch):
            """End-of-epoch callback: compute and log IoU/mIoU."""
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou*100.0)
            )
        return metrics_func
    def dump_output_data(data, iteration, writer):
        """Write (input | prediction | ground truth) panels to TensorBoard."""
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')
    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )
def main():
    """Entry point: run the segmentation finetune/evaluation defined above."""
    segmentation()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Vision-classification finetuning/evaluation."""
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0, print_rank_last
from megatron.core import mpu
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins
def segmentation():
    """Finetune/evaluate SETR segmentation with sliding-window inference."""
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds
    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)
    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks
    def calculate_weight(masks, num_classes):
        """Per-class loss weights from the label histogram: classes present
        in the batch get (2 - frequency), absent classes get 1."""
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float()/bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist
    def cross_entropy_loss_func(images, masks, output_tensor, non_loss_data=False):
        """Class-weighted cross-entropy; with `non_loss_data` True, returns
        a visualization tensor (input | prediction | ground truth as RGB)
        plus the loss instead."""
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight, ignore_index=ignore_index)
        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            # Map class ids to RGB via the color table for visualization.
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss
    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()
        # Get the batch.
        timers("batch generator", log_level=2).start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()
        # Forward model.
        # In eval mode the full image is split into sliding crops, which are
        # then forwarded in micro-batch-sized chunks.
        if not model.training:
            images, masks, _, _ = slidingcrops(images, masks)
        #print_rank_0("images size = {}".format(images.size()))
        if not model.training:
            output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])
        else:
            output_tensor = model(images)
        return output_tensor, partial(cross_entropy_loss_func, images, masks)
    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""
        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()
        def loss_func(labels, slices_info, img_size, output_tensor):
            """Accumulate per-class [tp, fp, tn, fn] counts, no real loss."""
            args = get_args()
            logits = output_tensor
            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            # Stitch sliding-window predictions back to full-image size.
            preds, labels = slidingjoins(preds, max_probs, labels, slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict
        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)
            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images, labels)
            # Forward model.
            output_tensor = torch.cat([model(image) for image in torch.split(images, args.micro_batch_size)])
            return output_tensor, partial(loss_func, labels, slices_info, img_size)
        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                # Sum per-class count matrices across micro-batches.
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']
            for m in model:
                m.train()
            # Reduce.
            if mpu.is_pipeline_last_stage():
                # Sum the counts over the data-parallel group, then derive
                # per-class IoU and NaN-safe mean IoU.
                torch.distributed.all_reduce(performs,
                                             group=mpu.get_data_parallel_group())
                # Print on screen.
                # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
                true_positive = performs[:, 0]
                false_positive = performs[:, 1]
                false_negative = performs[:, 3]
                iou = true_positive / (true_positive + false_positive + false_negative)
                miou = iou[~torch.isnan(iou)].mean()
                return iou.tolist(), miou.item()
    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )
        def metrics_func(model, epoch):
            """End-of-epoch callback: compute and log IoU/mIoU."""
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou*100.0)
            )
        return metrics_func
    def dump_output_data(data, iteration, writer):
        """Write (input | prediction | ground truth) panels to TensorBoard."""
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')
    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )
def main():
    """Entry point: run the segmentation finetune/evaluation defined above."""
    segmentation()
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#copyright (c) go-hiroaki & Chokurei
#email: guangmingwu2010@gmail.com
# guozhilingty@gmail.com
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# Numerical-stability constant: added to metric denominators to avoid
# division by zero.
eps = 1e-6
def _binarize(y_data, threshold):
"""
args:
y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
threshold : [float] [0.0, 1.0]
return 4-d binarized y_data
"""
y_data[y_data < threshold] = 0.0
y_data[y_data >= threshold] = 1.0
return y_data
def _argmax(y_data, dim):
"""
args:
y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
dim : int
return 3-d [int] y_data
"""
return torch.argmax(y_data, dim).int()
def _get_tp(y_pred, y_true):
"""
args:
y_true : [int] 3-d in [batch_size, img_rows, img_cols]
y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
return [float] true_positive
"""
return torch.sum(y_true * y_pred).float()
def _get_fp(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_positive
"""
return torch.sum((1 - y_true) * y_pred).float()
def _get_tn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] true_negative
"""
return torch.sum((1 - y_true) * (1 - y_pred)).float()
def _get_fn(y_pred, y_true):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_negative
"""
return torch.sum(y_true * (1 - y_pred)).float()
def _get_weights(y_true, nb_ch):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
nb_ch : int
return [float] weights
"""
batch_size, img_rows, img_cols = y_true.shape
pixels = batch_size * img_rows * img_cols
weights = [torch.sum(y_true==ch).item() / pixels for ch in range(nb_ch)]
return weights
class CFMatrix(object):
    """Per-class [tp, fp, tn, fn] counter used by the SETR evaluation loop."""
    def __init__(self, des=None):
        self.des = des
    def __repr__(self):
        return "ConfusionMatrix"
    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        # NOTE(review): despite the docstring, y_pred is unpacked as a 3-d
        # tensor of class ids here (the caller passes argmax-ed predictions).
        batch_size, img_rows, img_cols = y_pred.shape
        # Uses ignore_index as the class count; this relies on
        # ignore_index == num_classes (both 19 for Cityscapes) — TODO confirm.
        chs = ignore_index
        device = y_true.device
        if chs == 1:
            # Binary case: threshold both maps and count scalar tp/fp/tn/fn.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for class `ch`; pixels labeled
                # ignore_index are excluded from the negatives (y_false_ch)
                # so they count toward neither fp nor tn.
                # NOTE(review): these temporaries are allocated on the CPU;
                # indexing them with masks derived from y_true/y_pred assumes
                # those tensors are also on the CPU — confirm for GPU runs.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_false_ch[torch.logical_and((y_true != ch), (y_true != ignore_index))] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.FloatTensor([nb_tp, nb_fp, nb_tn, nb_fn])
            # Average of per-class rows weighted by class pixel frequency.
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class OAAcc(object):
    """Overall accuracy metric: matching pixels / total pixels."""
    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (tp+tn)/total
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        if chs == 1:
            # Binary case: threshold both maps in place.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
        else:
            # Multi-class case: reduce channel scores to class-id maps.
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
        matching = torch.sum(y_true == y_pred).float()
        mperforms = matching / (batch_size * img_rows * img_cols)
        performs = None
        return mperforms, performs
class Precision(object):
    """Precision metric: tp / (tp + fp).

    For multi-channel input, per-class precisions are averaged weighted by
    ground-truth class frequency.
    """
    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fp)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            # Binary case: threshold and count scalar tp/fp.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            # FIX: was `esp`, an undefined name (NameError at runtime);
            # use the module-level `eps` stability constant.
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for class `ch`.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)  # FIX: esp -> eps
            # Frequency-weighted mean of per-class precisions.
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Recall(object):
    """Recall metric: tp / (tp + fn).

    For multi-channel input, per-class recalls are averaged weighted by
    ground-truth class frequency.
    """
    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fn)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            # Binary case: threshold and count scalar tp/fn.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            # FIX: was `esp`, an undefined name (NameError at runtime);
            # use the module-level `eps` stability constant.
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # One-vs-rest masks for class `ch`.
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)  # FIX: esp -> eps
            # Frequency-weighted mean of per-class recalls.
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class F1Score(object):
    """Per-pixel F1 score (2*P*R / (P+R)) for binary or multi-class masks."""

    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return 2*precision*recall/(precision+recall)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            # Binary case: threshold, then a single F1 over all pixels.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + esp)
            _recall = nb_tp / (nb_tp + nb_fn + esp)
            mperforms = 2 * _precision * _recall / (_precision + _recall + esp)
            performs = None
        else:
            # Multi-class case: per-channel F1, weight-averaged.
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # Fix: allocate one-hot masks on the input's device (they
                # were always CPU before, which fails for CUDA inputs).
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + esp)
                _recall = nb_tp / (nb_tp + nb_fn + esp)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Kappa(object):
    """Cohen's kappa ((Po - Pe) / (1 - Pe)) for binary or multi-class masks."""

    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (Po-Pe)/(1-Pe)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            # Binary case: observed agreement Po vs chance agreement Pe.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
            mperforms = (Po - Pe) / (1 - Pe + esp)
            performs = None
        else:
            # Multi-class case: per-channel kappa, weight-averaged.
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # Fix: allocate one-hot masks on the input's device (they
                # were always CPU before, which fails for CUDA inputs).
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn)
                      + (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total**2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Jaccard(object):
    """Jaccard index / IoU (|A∩B| / |A∪B|) for binary or multi-class masks."""

    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return intersection / (sum-intersection)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            # Binary case: |A∪B| = |A| + |B| - |A∩B|.
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + esp)
            performs = None
        else:
            # Multi-class case: per-channel IoU, weight-averaged.
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                # Fix: allocate one-hot masks on the input's device (they
                # were always CPU before, which fails for CUDA inputs).
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols,
                                        device=device)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + esp)
            mperforms = sum([i*j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class MSE(object):
    """Mean squared error between prediction and target, smaller the better."""

    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return mean_squared_error, smaller the better
        """
        # Optionally binarize the prediction before scoring.
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        diff = y_pred - y_true
        return torch.mean(diff * diff)
class PSNR(object):
    """Peak signal-to-noise ratio (assumes a unit peak), larger the better."""

    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return PSNR, larger the better
        """
        # Optionally binarize the prediction before scoring.
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        mean_sq_err = torch.mean((y_pred - y_true) ** 2)
        # PSNR with peak value 1 (inputs presumably in [0, 1]).
        return 10 * torch.log10(1 / mean_sq_err)
class SSIM(object):
    '''
    Structural similarity index, larger the better.
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''

    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        # Normalized 1-D gaussian kernel of length w_size.
        center = w_size // 2
        kernel = torch.Tensor(
            [math.exp(-(x - center)**2 / float(2 * sigma**2))
             for x in range(w_size)])
        return kernel / kernel.sum()

    def create_window(self, w_size, channel=1):
        # Outer product of the 1-D kernel with itself gives a 2-D gaussian,
        # replicated per channel for use with a grouped convolution.
        kernel_1d = self.gaussian(w_size, 1.5).unsqueeze(1)
        kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
        return kernel_2d.expand(channel, 1, w_size, w_size).contiguous()

    def __call__(self, y_pred, y_true, w_size=11, size_average=True, full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False
        return ssim, larger the better
        """
        # Guess the dynamic range L from the data: 255 for uint8-style
        # images, otherwise 1 (sigmoid-range) or 2 (tanh-range) data.
        max_val = 255 if torch.max(y_pred) > 128 else 1
        min_val = -1 if torch.min(y_pred) < -0.5 else 0
        L = max_val - min_val

        padd = 0
        channel = y_pred.size(1)
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        # Local means via gaussian filtering.
        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)
        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        # Local variances and covariance.
        sigma1_sq = F.conv2d(y_pred * y_pred, window, padding=padd,
                             groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window, padding=padd,
                             groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window, padding=padd,
                           groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2
        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity
        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
        if full:
            return ret, cs
        return ret
class AE(object):
    """
    Average angular error between color vectors, smaller the better.
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """

    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE, smaller the better
        """
        # Per-pixel dot product and vector norms along the channel axis.
        inner = torch.sum(y_pred * y_true, dim=1)
        norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        # NOTE(review): relies on a module-level `eps`, while the other
        # metrics in this file use `esp` -- confirm both constants exist.
        angle = 180 / math.pi * torch.acos(inner / (norm_pred * norm_true + eps))
        # Average over the spatial dimensions, one value per batch item.
        return angle.mean(1).mean(1)
# Smoke test: run every metric on a random image corrupted with gaussian
# noise, once for 3-channel and once for single-channel input.
if __name__ == "__main__":
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        # y_pred is y_true plus zero-mean gaussian noise (std 0.1).
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        # NOTE(review): the cuda=True pass assumes a CUDA device is
        # available -- there is no torch.cuda.is_available() guard here.
        for cuda in [False, True]:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()
            print('#'*20, 'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            # LPIPS and OAAcc below are defined earlier in this file
            # (outside this chunk).
            metric = LPIPS(cuda)
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)
            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)
            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)
            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)
            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)
            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize
class SetrSegmentationHead(MegatronModule):
    """Decodes ViT token embeddings into a full-resolution class-logit map."""

    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim
        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        # 1x1 conv + sync batch-norm + tanh, then 1x1 conv to class logits.
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        # Reshape the [n, h*w, c] token sequence back into an [n, c, h, w]
        # feature map on the patch grid.
        batch, num_tokens, channels = x.shape
        grid_h = self.img_h // self.patch_dim
        grid_w = self.img_w // self.patch_dim
        assert num_tokens == grid_h * grid_w
        return x.transpose(1, 2).reshape(batch, channels, grid_h, grid_w)

    def forward(self, hidden_states):
        # hidden_states: [b, h*w, c] backbone tokens.
        out = self.layernorm(hidden_states)
        out = self.to_2D(out)
        out = self.conv_0(out)
        out = self.norm_0(out)
        out = torch.tanh(out)
        out = self.conv_1(out)
        # Upsample logits to the input image resolution: [b, classes, H, W].
        return F.interpolate(out,
                             size=(self.img_h, self.img_w),
                             mode='bilinear')
class MLP(torch.nn.Module):
    """
    Linear Embedding: flattens the spatial dims of a feature map and
    projects each position's channel vector to embed_dim.
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        # (n, c, h, w) -> (n, h*w, c) -> (n, h*w, embed_dim)
        tokens = x.flatten(2).transpose(1, 2)
        return self.proj(tokens)
class SegformerSegmentationHead(MegatronModule):
    """Segformer all-MLP decode head.

    Projects each of the four backbone feature maps to a common embedding
    dim, upsamples them all to the stride of the finest map (c1), fuses
    them with a 1x1 conv, and predicts per-pixel class logits.
    """

    def __init__(self, feature_strides, in_channels,
                 embedding_dim, dropout_ratio):
        super(SegformerSegmentationHead, self).__init__()
        # One stride per input feature map, finest (smallest stride) first.
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]
        args = get_args()
        self.feature_strides = feature_strides
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_classes = args.num_classes
        self.dropout_ratio = dropout_ratio
        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels
        # One per-stage linear projection into the shared embedding space.
        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)
        # Fuses the 4 concatenated embeddings back to embedding_dim.
        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4,
                                         self.embedding_dim, 1, 1)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
        # Final 1x1 conv producing per-pixel class logits.
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.num_classes,
                                           kernel_size=1)

    def forward(self, inputs):
        # inputs: 4 feature maps (c1 finest .. c4 coarsest), each [n, c, h, w].
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape

        # MLP outputs [n, h*w, embed]; permute/reshape restores the 2-D map,
        # then each map is resized up to c1's spatial resolution.
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        # c1 is already at the target resolution, so no resize.
        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        # Fuse (coarse-to-fine order), normalize, activate, and predict.
        _c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        x = self.linear_pred(x)

        return x
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead
class SetrSegmentationModel(MegatronModule):
    """SETR segmentation model: ViT backbone + convolutional decode head."""

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        # This model does not support pipeline splitting; both stages must
        # be enabled. (Fix: use logical `and` instead of bitwise `&`, which
        # does not short-circuit and misbehaves for non-bool truthy values.)
        assert pre_process and post_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )
        self.head = SetrSegmentationHead(
            self.hidden_size,
            self.num_classes
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # No-op: no pipeline parallelism for this model.
        pass

    def forward(self, input):
        # input: image batch; backbone yields [b, h*w, c] tokens, the head
        # turns them into full-resolution per-pixel logits.
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final
class SegformerSegmentationModel(MegatronModule):
    """Segformer segmentation model: MiT-B5 backbone + all-MLP decode head."""

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = mit_b5()
        # Strides/channels match the four MiT-B5 stages.
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # No-op: no pipeline parallelism for this model.
        pass

    def forward(self, input):
        # input: image batch -> multi-scale backbone features -> logits.
        features = self.backbone(input)
        logits = self.head(features)
        return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment