import numpy as np
from . common import HFTask
from lm_eval.base import rf, mean

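"""
This evaluation of Winogrande uses partial evaluation as described by
Trinh & Le in "A Simple Method for Commonsense Reasoning" (2018): each candidate
option is substituted for the blank, and the candidate under which the language
model assigns a higher likelihood to the rest of the sentence is the prediction.
Reference: https://arxiv.org/abs/1806.02847
"""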


class Winogrande(HFTask):
    """Winogrande (XL configuration) scored with partial evaluation.

    For each document both candidate options are substituted for the ``_``
    blank and the text after the blank is dropped. The model's log-likelihood
    of that remaining text (the "target") is then compared between the two
    resulting contexts; the option with the higher likelihood is the
    prediction.
    """

    DATASET_PATH = "winogrande"
    DATASET_NAME = "winogrande_xl"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        # The public Winogrande test split ships without labels (its "answer"
        # field is empty), so `int(doc["answer"])` in `process_results` would
        # raise on every test doc. Report no test docs so the validation split
        # is used for evaluation instead.
        return False

    def fewshot_description(self):
        # TODO: redo description. The text also contains the typo "a either a";
        # it is deliberately left untouched so prompts stay comparable with
        # earlier runs.
        return "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in."

    @classmethod
    def partial_context(cls, doc):
        """Substitute each candidate for the blank and drop everything after it.

        :param doc:
            A Winogrande doc with a "sentence" containing a "_" blank, plus
            "option1" and "option2".
        :returns: (str, str)
            The sentence text before the first "_" followed by option1, and the
            same prefix followed by option2.
        :raises ValueError:
            If the sentence contains no "_".
        """
        sentence = doc["sentence"]
        prefix = sentence[:sentence.index("_")]
        return prefix + doc["option1"], prefix + doc["option2"]

    @classmethod
    def partial_target(cls, doc):
        """Return the sentence text after the first "_", whitespace-stripped."""
        # The target is everything after the document-specified blank.
        pronoun_loc = doc["sentence"].index("_") + 1
        return doc["sentence"][pronoun_loc:].strip()

    def doc_to_text(self, doc):
        """Return the sentence prefix completed with the correct option.

        Few-shot demonstrations are this text followed by `doc_to_target`,
        which reads as the correctly resolved sentence. For the document being
        evaluated, `construct_requests` removes this text from `ctx` again, so
        the gold answer is never part of a scored context.

        Docs without a usable label ("answer" not "1"/"2") fall back to listing
        both completed prefixes on separate lines.
        """
        context1, context2 = self.partial_context(doc)
        answer = str(doc.get("answer", ""))
        if answer == "1":
            return context1
        if answer == "2":
            return context2
        return context1 + '\n' + context2 + '\n'

    def doc_to_target(self, doc):
        """Return the text after the blank, with the leading space that joins
        it to `doc_to_text`'s output (matching how it is scored in
        `construct_requests`)."""
        return " " + self.partial_target(doc)

    def _fewshot_prefix(self, doc, ctx):
        """Return the part of `ctx` that precedes this doc's own text.

        NOTE(review): assumes the base class's `fewshot_context` builds `ctx`
        as <description / few-shot examples> followed by `doc_to_text(doc)`,
        possibly `.strip()`-ed — TODO confirm. If `ctx` does not end with that
        text, an empty prefix is returned (i.e. `ctx` is ignored), so the gold
        option contained in `doc_to_text` can never leak into a request.
        """
        own_text = self.doc_to_text(doc)
        for suffix in (own_text, own_text.strip()):
            if suffix and ctx.endswith(suffix):
                return ctx[:len(ctx) - len(suffix)]
        return ""

    def construct_requests(self, doc, ctx):
        """ Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        One log-likelihood request is issued per candidate option. Each request
        scores the same target (the text after the blank) given the few-shot
        prefix taken from `ctx` plus the sentence prefix filled with that
        option. In the zero-shot case the prefix is empty.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns:
            (ll_option1, ll_option2) log-likelihood request results.
        """
        prefix = self._fewshot_prefix(doc, ctx)
        target = " " + self.partial_target(doc)
        context1, context2 = self.partial_context(doc)
        ll_context1, _ = rf.loglikelihood(prefix + context1, target)
        ll_context2, _ = rf.loglikelihood(prefix + context2, target)
        return ll_context1, ll_context2

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests, in
            (option1, option2) order.
        :returns:
            {"acc": whether the higher-likelihood option is the labeled answer}
        """
        answer = int(doc["answer"]) - 1  # `- 1` b/c doc["answer"] ∈ {'1', '2'}
        return {
            "acc": np.argmax(results) == answer
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "acc": mean
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "acc": True
        }
