import re
import numpy as np
from ..base import rf, mean
from . common import HFTask


class HellaSwag(HFTask):
    """HellaSwag sentence-completion task backed by the HF ``hellaswag`` dataset.

    Each document supplies an ``activity_label``, a context ``ctx``, a list of
    candidate ``endings`` and a ``label`` holding the (string) index of the
    correct ending. The model scores every ending by log-likelihood given the
    context; the prediction is the highest-scoring ending.
    """
    DATASET_PATH = "hellaswag"
    DATASET_NAME = None

    # Non-greedy, so each "[...]" span is removed separately rather than
    # everything between the first "[" and the last "]". Raw string: the
    # escape "\[" is invalid in a plain string literal and triggers a
    # DeprecationWarning (SyntaxWarning on Python 3.12+).
    _BRACKET_RE = re.compile(r"\[.*?\]")

    @classmethod
    def remove_brackets(cls, text):
        """Remove every ``[...]`` span from ``text`` and return the result.

        NOTE: The brackets are artifacts of the WikiHow dataset portion
        underlying HellaSwag.
        """
        return cls._BRACKET_RE.sub('', text)

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        # NOTE(review): the public HF test split presumably has withheld
        # (empty) labels, in which case doc_to_target/process_results raise
        # ValueError on test docs — confirm before evaluating on "test".
        return True

    def training_docs(self):
        if self.has_training_docs():
            return self.data["train"]

    def validation_docs(self):
        if self.has_validation_docs():
            return self.data["validation"]

    def test_docs(self):
        if self.has_test_docs():
            return self.data["test"]

    def fewshot_description(self):
        return "Label for the relevant action: Sentences describing the " \
            "context, with an incomplete sentence trailing\nanswer that " \
            "plausibly completes the situation."

    def doc_to_text(self, doc):
        """Return ``"<activity_label>: <ctx>\\n"`` with bracket artifacts removed.

        The trailing newline means the continuation (target or candidate
        ending) starts on its own line.
        """
        text = doc['activity_label'] + ': ' + doc['ctx'] + '\n'
        return self.remove_brackets(text)

    @staticmethod
    def _label_index(doc):
        """Return the index of the gold ending for ``doc`` as an ``int``.

        ``doc['label']`` must be the string form of a valid index into
        ``doc['endings']`` (e.g. ``'2'``).

        :raises ValueError: if the label is not such an index (for example an
            empty label).
        """
        valid = {str(i): i for i in range(len(doc['endings']))}
        index = valid.get(str(doc['label']))
        if index is None:
            raise ValueError(
                "HellaSwag from HF datasets contained an invalid answer key")
        return index

    def doc_to_target(self, doc):
        """Return the gold ending for ``doc`` with bracket artifacts removed.

        :raises ValueError: if ``doc['label']`` is not a valid ending index.
        """
        target = doc['endings'][self._label_index(doc)]
        return self.remove_brackets(target)

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        One log-likelihood request is built per candidate ending, in the order
        of ``doc['endings']``, so result ``i`` scores ending ``i``.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        return [
            rf.loglikelihood(ctx, self.remove_brackets(ending))
            for ending in doc['endings']
        ]

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests, one per
            ending in the same order. NOTE(review): assumes each result is a
            single scalar log-likelihood; if rf.loglikelihood yields tuples,
            np.argmax would index the flattened array — confirm.
        :raises ValueError: if ``doc['label']`` is not a valid ending index.
        """
        gold = self._label_index(doc)
        pred = int(np.argmax(results))
        return {
            "acc": 1. if pred == gold else 0.
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }
