"vscode:/vscode.git/clone" did not exist on "7dca70049566b5b1c55cbd67e1cb191729a98152"
Commit 2eaf6c79 authored by Mostofa Patwary

cleaning the code

parent 7a0710ec
@@ -26,8 +26,8 @@ class IndexBuilder(object):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        self.pre_process = True
-        self.post_process = True
+        #self.pre_process = True
+        #self.post_process = True
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
@@ -46,7 +46,7 @@ class IndexBuilder(object):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        args = get_args()
+        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
@@ -56,7 +56,7 @@ class IndexBuilder(object):
         #model = get_model(biencoder_model_provider)
         model = get_model(get_model_provider(only_context_model=only_context_model,
             biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
         #model = get_model(lambda: biencoder_model_provider(only_context_model \
@@ -103,12 +103,12 @@ class IndexBuilder(object):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
-        counter = 0
-        start_time = time.time()
-        cur_time = start_time
+        #counter = 0
+        #start_time = time.time()
+        #cur_time = start_time
         while True:
             #start_time = time.time()
-            t1 = time.time()
+            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
@@ -118,7 +118,7 @@ class IndexBuilder(object):
                     break
             #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            t2 = time.time()
+            #t2 = time.time()
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
@@ -129,17 +129,17 @@ class IndexBuilder(object):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
             #print_rank_0("embed text {}".format(cur_time - time.time()))
-            t3 = time.time()
+            #t3 = time.time()
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
             #print_rank_0("add block time {}".format(cur_time - time.time()))
-            t4 = time.time()
-            counter += 1
-            if counter % 1000 == 0:
-                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
-                cur_time = time.time()
+            #t4 = time.time()
+            #counter += 1
+            #if counter % 1000 == 0:
+            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
+            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 - t3))
+            #    cur_time = time.time()
         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
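The TODO in the loop above asks whether the embedding pass can run under torch.no_grad() to reduce memory usage. A minimal sketch of that idea, assuming the call pattern of the surrounding loop (the wrapper name and usage are illustrative, not part of this commit):

```python
import torch

def embed_no_grad(embed_fn, *inputs):
    # Indexing is pure inference: running the forward pass under
    # torch.no_grad() skips building the autograd graph, so activations
    # that would only be needed for backward are never retained.
    with torch.no_grad():
        return embed_fn(*inputs)

# Hypothetical usage inside the loop above (embed_text signature assumed):
# context_logits = embed_no_grad(unwrapped_model.embed_text,
#                                context_tokens, context_mask, context_types)
```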
......
@@ -15,17 +15,17 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule

 def get_model_provider(only_query_model=False, only_context_model=False,
                        biencoder_shared_query_context_model=False):

     def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         print_rank_0('building Biencoder model ...')
         model = biencoder_model_provider(only_query_model=only_query_model,
             only_context_model=only_context_model,
             biencoder_shared_query_context_model=\
                 biencoder_shared_query_context_model,
             pre_process=True, post_process=True)
         return model
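Two details of this provider are worth noting: Megatron's get_model expects a provider that takes only pre_process and post_process, so the biencoder-specific flags are captured in a closure; and the inner model_provider shown here pins pre_process=True, post_process=True rather than forwarding its own arguments. A sketch, under those assumptions, of the same closure with the arguments forwarded:

```python
def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):
    # get_model() only knows the (pre_process, post_process) signature, so
    # the biencoder flags are baked in via this closure.
    def model_provider(pre_process=True, post_process=True):
        return biencoder_model_provider(
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=
                biencoder_shared_query_context_model,
            # forward the pipeline flags instead of hard-coding True
            pre_process=pre_process, post_process=post_process)
    return model_provider
```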
@@ -33,21 +33,12 @@ def get_model_provider(only_query_model=False, only_context_model=False,
     return model_provider

-#def biencoder_model_provider(pre_process=True,
-#                             post_process=True):
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
                              biencoder_shared_query_context_model=False,
                              pre_process=True,
                              post_process=True):
     """Build the model."""

-    #args = get_args()
-    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
-    #only_context_model = args.only_context_model
-    #only_query_model = args.only_query_model

     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -63,7 +54,7 @@ def biencoder_model_provider(only_query_model=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
         biencoder_shared_query_context_model=\
             biencoder_shared_query_context_model,
         pre_process=pre_process,
         post_process=post_process)
@@ -114,9 +105,9 @@ class BiEncoderModel(MegatronModule):
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
-        #this is just a placeholder and will be needed when model
-        #parallelism will be used
-        #self.language_model.set_input_tensor(input_tensor)
+        # this is just a placeholder and will be needed when model
+        # parallelism will be used
+        # self.language_model.set_input_tensor(input_tensor)
         return

     def forward(self, query_tokens, query_attention_mask, query_types,
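The set_input_tensor placeholder above matters once pipeline parallelism is enabled: intermediate stages receive their input from the previous stage rather than from an embedding layer. A sketch of what the commented-out line would eventually do (assumed behavior, mirroring the megatron.model.transformer docstring referenced above):

```python
def set_input_tensor(self, input_tensor):
    # Under pipeline parallelism, the schedule hands this stage the
    # previous stage's output; forward() then reads it instead of
    # computing embeddings from token ids.
    self.language_model.set_input_tensor(input_tensor)
```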
......
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
......
@@ -36,7 +36,7 @@ def pretrain_ict_model_provider():
     #args.only_context_model = False
     #args.only_query_model = False
     #model = biencoder_model_provider()
     model = biencoder_model_provider(
         only_context_model=False,
         only_query_model=False,
......
@@ -16,7 +16,6 @@
 """Finetune utilities."""

 from functools import partial
-import sys

 import torch
@@ -81,7 +80,7 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)

 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                       task_collate_fn=None):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
@@ -190,7 +189,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 continue
             # Set to zero so the next epoch does not skip any batches.
             start_iteration = 0

             # Train for one step.
             out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
@@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                          valid_dataloader, model,
                                          iteration, False)
-            #if iteration == 600:
-            #    sys.exit()

         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
......
@@ -89,8 +89,8 @@ def get_tasks_args(parser):
     #                    help="Av.rank validation: batch size to process passages")
     #group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
     #                    help="Av.rank validation: max num of questions")

     return parser
......
@@ -15,18 +15,6 @@
 """Main tasks functionality."""

-import os
-import sys
-#sys.path.append(
-#    os.path.abspath(
-#        os.path.join(
-#            os.path.join(os.path.dirname(__file__), os.path.pardir),
-#            os.path.pardir,
-#        )
-#    )
-#)

 from megatron import get_args, print_rank_0
 from megatron.indexer import IndexBuilder
 from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -35,30 +23,23 @@ def main():
     """
     Main program
     """
-    #initialize_megatron(extra_args_provider=None,
-    #    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

     args = get_args()

-    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
-    - Include all args needed for initial model specification
-    Other key args:
-        --block-data-path: path to write to
-        --ict-load or --realm-load: path to checkpoint with which to embed
-        --data-path and --titles-data-path: paths for dataset
-        --indexer-log-interval: reporting interval
-        --indexer-batch-size: size specific for indexer jobs
-    Check README.md for example script
+    """
+    Create a BlockData data structure by running an IndexBuilder over an
+    ICT Dataset and then evaluate on NQ task
     """

-    #print_rank_0("Starting index builder!")
+    print_rank_0("Starting index builder!")
     index_builder = IndexBuilder()
     index_builder.build_and_save_index()
     print_rank_0("Build and save indices: done!")

+    print_rank_0("Starting evaluations!")

     # Set up the model and evaluator
     evaluator = ORQAEvaluator()
@@ -68,4 +49,4 @@ def main():
     if args.qa_data_test is not None:
         evaluator.evaluate(args.qa_data_test, "TEST")
......
@@ -47,10 +47,9 @@ class ORQAEvaluator(object):
         #args.only_query_model = only_query_model
         #args.only_context_model = False
         model = get_model(get_model_provider(only_query_model=only_query_model,
             biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
         #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #    only_query_model, biencoder_shared_query_context_model=\
......
@@ -104,9 +104,9 @@ def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
     return enc_ids, tokentypes_enc, pad_mask

 def build_sample(query_ids, query_types, query_pad_mask,
                  ctx_ids, ctx_types, ctx_pad_mask, answers,
                  neg_ctx_id_list=None, neg_ctx_types_list=None,
                  include_neg=False):
     """Convert to numpy and return a sample consumed by the batch producer."""
@@ -295,5 +295,3 @@ class NQSupervisedDataset(OpenRetrievalAbstractDataset):
         print_rank_0(' >> processed {} samples.'.format(len(samples)))
         return samples
......
@@ -34,7 +34,6 @@ def task_collate_fn(batch_data):
     for d in batch_data:
         for k, v in d.items():
             tensorized.setdefault(k, []).append(v)
-    # assert len(tensorized) == 12

     tensorized['query'] = torch.LongTensor(tensorized['query'])
     tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
@@ -90,8 +89,6 @@ def process_batch(batch):
         neg_context_tokens, neg_context_mask, neg_context_types, reference

 def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
-#, datapath,
-#                           rank0sampler=False):
     """Provide function that calculates accuracies."""
     args = get_args()
@@ -112,9 +109,7 @@ def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
                                    args.eval_micro_batch_size,
                                    num_workers=args.num_workers,
                                    drop_last=drop_last,
                                    task_collate_fn=task_collate_fn)
-                                   #shuffle=False,
-                                   #rank0sampler=rank0sampler)
     dataloaders = (dataset.dataset_name, dataloader)

     def metrics_func(model, epoch, output_predictions=False):
@@ -197,7 +192,7 @@ def retrieval_loss(model, dataloader):
         losses = average_losses_across_data_parallel_group([rank, \
             *topk_accs])

         # create stats_dict with retrieval loss and all specified
         # top-k accuracies
         topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
             zip(args.retriever_report_topk_accuracies, losses[1:])}
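retrieval_loss reports one accuracy per k in args.retriever_report_topk_accuracies, averaged across the data-parallel group. A self-contained sketch of computing such top-k retrieval accuracies from a score matrix, under the in-batch convention that context i is the gold passage for query i (the distributed averaging and the model forward are omitted; names are illustrative):

```python
import torch

def topk_retrieval_accuracies(retrieval_scores, topk_list):
    # retrieval_scores: [num_queries, num_contexts] similarity matrix,
    # with context i assumed to be the gold passage for query i.
    num_queries = retrieval_scores.size(0)
    gold = torch.arange(num_queries).unsqueeze(1)      # shape [q, 1]
    _, topk_idx = retrieval_scores.topk(max(topk_list), dim=1)
    hits = topk_idx.eq(gold)                           # shape [q, max_k]
    # A query scores a hit at k if the gold id is among its top-k contexts.
    return {'top{}_acc'.format(k):
                100.0 * hits[:, :k].any(dim=1).float().mean().item()
            for k in topk_list}

# Example: random scores for 8 queries over 8 contexts.
print(topk_retrieval_accuracies(torch.randn(8, 8), [1, 5]))
```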
......
@@ -22,27 +22,21 @@
 import math

 import torch
 import torch.nn.functional as F

-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron import mpu
-from megatron import print_rank_0
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune
+from megatron.utils import average_losses_across_data_parallel_group
 from pretrain_ict import get_group_world_size_rank
 from tasks.finetune_utils import finetune
 from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
-from megatron.indexer import IndexBuilder

-def orqa(Dataset): # , name_from_datapath_func):
+def orqa(Dataset):

     def cross_entropy_forward_step(batch, model):
         """Simple forward step with cross-entropy loss."""
-        args = get_args()
         timers = get_timers()
         tokenizer = get_tokenizer()
@@ -73,17 +67,15 @@ def orqa(Dataset): # , name_from_datapath_func):
             context_types = torch.cat([context_types, neg_context_types])

         # Forward model.
-        #query_logits, context_logits = model(query_tokens, query_mask,
         output_tensor = model(query_tokens, query_mask,
                               query_types, context_tokens,
                               context_mask, context_types)

-        return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens)
+        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)

-    #def cross_entropy_loss_func(labels, output_tensor):
-    def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
+    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
         args = get_args()
         local_batch_size = query_tokens.shape[0]
         group, rank, world_size = get_group_world_size_rank()
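cross_entropy_loss_func treats retrieval as classification over the contexts in the batch: each query's positive is its own context, and every other context, including the appended hard negatives, acts as a negative class. A self-contained sketch of that objective (dimensions illustrative; the real function also all-gathers query and context logits across the data-parallel group before scoring):

```python
import torch
import torch.nn.functional as F

def in_batch_retrieval_loss(query_logits, context_logits):
    # query_logits: [q, dim]; context_logits: [c, dim] with c >= q when
    # hard negatives are appended. Row i scores query i against every context.
    retrieval_scores = torch.matmul(query_logits, context_logits.t())
    # Context i is the positive for query i; all other columns are negatives.
    labels = torch.arange(query_logits.size(0), device=query_logits.device)
    log_probs = F.log_softmax(retrieval_scores, dim=1)
    return F.nll_loss(log_probs, labels, reduction='mean')

# Example: 4 queries scored against 8 contexts (4 positives + 4 negatives).
loss = in_batch_retrieval_loss(torch.randn(4, 64), torch.randn(8, 64))
```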
@@ -184,12 +176,9 @@ def orqa(Dataset): # , name_from_datapath_func):
         """Build the model."""
         args = get_args()
         print_rank_0('building retriever model for {} ...'.format(args.task))
-        #args.only_context_model=False
-        #args.only_query_model=False
-        #model = biencoder_model_provider()
         model = biencoder_model_provider(only_context_model=False,
             only_query_model=False,
             biencoder_shared_query_context_model=\
                 args.biencoder_shared_query_context_model,
             pre_process=pre_process, post_process=post_process)
@@ -200,7 +189,6 @@ def orqa(Dataset): # , name_from_datapath_func):
         args = get_args()
         tokenizer = get_tokenizer()

-        #name = name_from_datapath_func(datapath)
         name = datapath[0].split('/')[-1].split('.')[0]

         return Dataset(name,
                        datapath,
@@ -208,41 +196,25 @@ def orqa(Dataset): # , name_from_datapath_func):
                        args.retriever_seq_length,
                        evaluate=True)

-    #def distributed_metrics_func_provider():
     def metrics_func_provider():
         """Provide metrics callback function."""
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
         return accuracy_func_provider(single_dataset_provider)

-    #def rank0_metrics_func_provider(datapath):
-    #    """Provide metrics callback function."""
-    #    return accuracy_func_provider(single_dataset_provider, datapath,
-    #                                  rank0sampler=True)

     """Finetune/evaluate."""
     finetune(train_valid_datasets_provider,
              model_provider,
              forward_step=cross_entropy_forward_step,
              end_of_epoch_callback_provider=metrics_func_provider,
              task_collate_fn=task_collate_fn)
-    #,end_of_training_callback_provider=rank0_metrics_func_provider)

 def main():
     args = get_args()

     if args.task == 'RET-FINETUNE-NQ':
         from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
     else:
         raise NotImplementedError('ORQA task {} is not implemented.'.format(
             args.task))

-    orqa(Dataset) #, name_from_datapath)
+    orqa(Dataset)