import numpy as np
import json
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef
from tqdm import auto as tqdm_lib
from . common import HFTask, simple_accuracy_metric, yesno
from pathlib import Path
from ..base import Dataset

class DROP(Dataset):
    """Question-answering task whose documents are read from local JSON files.

    Each document is expected to be a dict with a ``"passage"`` string and a
    ``"qa_pairs"`` list; every pair carries a ``"question"`` string and an
    ``"answer"`` dict with ``"number"``, ``"spans"`` and ``"date"`` entries
    (this is the schema ``doc_to_text`` reads).

    Evaluation (``construct_requests``, ``process_results``, ``aggregation``,
    ``higher_is_better``) is not implemented yet and raises
    ``NotImplementedError``.
    """

    # Directory containing drop_dataset_train.json / drop_dataset_dev.json.
    DATAFOLDER = Path(__file__).parent / "../../data/drop"

    def __init__(self):
        super().__init__()

    def _load_docs(self, filename):
        """Read ``filename`` from ``DATAFOLDER`` and return its values as a list.

        :param filename: str
            Name of a JSON file whose top-level value is an object.
        :returns: list
            The values of the top-level object, in file order.
        """
        # Context manager guarantees the handle is closed even if parsing
        # fails; JSON files are read as UTF-8 regardless of platform default.
        with open(self.DATAFOLDER / filename, encoding="utf-8") as f:
            docs = json.load(f)
        return list(docs.values())

    @staticmethod
    def _get_answer(ans_dict):
        """Render an answer dict as the target string.

        Precedence: a non-empty ``number``, then the ``spans`` joined with
        ``', '``, then the date as ``"day month year"`` with surrounding
        whitespace stripped.

        :param ans_dict: dict
            Answer with ``"number"``, ``"spans"`` and ``"date"`` keys.
        :returns: str
        """
        if ans_dict['number'] != '':
            return ans_dict['number']
        if ans_dict['spans'] != []:
            return ', '.join(ans_dict['spans'])
        date = ans_dict['date']
        # strip() removes the padding left by empty leading/trailing parts.
        return ' '.join([date['day'], date['month'], date['year']]).strip()

    def has_training_docs(self):
        """Whether the task has a training set"""
        return True

    def has_validation_docs(self):
        """Whether the task has a validation set"""
        return True

    def has_test_docs(self):
        """Whether the task has a test set"""
        return False

    def training_docs(self):
        """Return the list of documents in ``drop_dataset_train.json``."""
        return self._load_docs('drop_dataset_train.json')

    def validation_docs(self):
        """Return the list of documents in ``drop_dataset_dev.json``."""
        return self._load_docs('drop_dataset_dev.json')

    def test_docs(self):
        """No test set is available (see ``has_test_docs``); returns ``None``."""
        return None

    def doc_to_text(self, doc, include_target=True):
        """Format a document as a passage followed by its questions.

        :param doc: dict
            A document with ``"passage"`` and ``"qa_pairs"`` keys.
        :param include_target: bool
            When true, each ``Answer: `` prompt is followed by the answer text.
        :returns: str
            ``"Passage: ...\\n"`` followed by one
            ``"Question: ...\\nAnswer: ..."`` entry per QA pair, with entries
            separated by newlines.
        """
        doctext = "Passage: {}\n".format(doc["passage"])
        qa_texts = []
        for pair in doc["qa_pairs"]:
            text = 'Question: ' + pair['question'] + '\nAnswer: '
            if include_target:
                text += self._get_answer(pair['answer'])
            qa_texts.append(text)
        return doctext + '\n'.join(qa_texts)

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # TODO: implement evaluation.
        raise NotImplementedError('Evaluation not implemented')

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        # TODO: implement evaluation.
        raise NotImplementedError('Evaluation not implemented')

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        # TODO: implement evaluation.
        raise NotImplementedError('Evaluation not implemented')

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        # TODO: implement evaluation.
        raise NotImplementedError('Evaluation not implemented')