import os

from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
from lm_eval.models.openai_completions import LocalChatCompletion
from lm_eval.utils import eval_logger


@register_filter("judge")
class JudgeFilter(Filter):
    """A filter that evaluates the correctness of a question-answering system's answers, using an LM Judge.

    Each model response is combined with its document's ``question`` and
    ``answer`` fields, prefixed with the judge prompt (``PROMPT`` unless a
    custom one is supplied), and sent to the judge model through
    ``LocalChatCompletion``. The judge's raw reply replaces the response.
    """

    # Default judge prompt.
    # NOTE(review): the first paragraph allows "unknown" as a verdict, while
    # later lines demand exactly "yes" or "no". The text is left unchanged so
    # judge results stay comparable with earlier runs.
    PROMPT = """You are an expert evaluator of question-answering systems. Your task is to determine if a given answer matches the ground truth answer in meaning and accuracy. You should respond with "yes", "no", or "unknown".

    Guidelines for evaluation:
    1. For multiple-choice questions, the answer choice letters are enough to determine correctness
    2. Focus on semantic meaning rather than exact wording
    3. Consider numerical accuracy when applicable
    4. Account for partial answers that contain the correct information plus additional details
    5. Recognize equivalent phrasings and synonyms
    6. Be lenient with minor grammatical differences
    7. For multi-part questions, all parts must be correct
    8. For questions requiring specific units, check unit correctness. However, if the answer is correct in all other aspects, you may overlook minor unit errors

    Input format:
    Question: [The question being asked]
    Answer: [The answer given by the system]
    Ground Truth: [The known correct answer]

    Your response must be exactly "yes" or "no", with no additional explanation.

    Example 1:
    Question: What is the capital of France?
    Answer: The capital city of France is Paris.
    Ground Truth: Paris
    Your response: yes

    Example 2:
    Question: How many planets are in our solar system?
    Answer: There are seven planets in our solar system.
    Ground Truth: Eight planets (Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune)
    Your response: no

    Example 3:
    Question: What is the GDP of France in 2023?\nA. 2 trillion USD\nB. 3.05 trillion USD\nC. 2.5 trillion USD
    Answer: B.
    Ground Truth: 3.05 trillion USD
    Your response: yes

    Your response must be exactly "yes" or "no", with no additional explanation!
    """

    def __init__(self, url, model, prompt=None, **kwargs) -> None:
        """Create a judge backed by a ``LocalChatCompletion`` client.

        Args:
            url: Endpoint of the judge model, passed as ``base_url``.
            model: Judge model name, passed as ``pretrained``.
            prompt: Optional replacement for ``PROMPT``. It is prepended to
                every question / answer / ground-truth triple.
            **kwargs: Forwarded to ``LocalChatCompletion``. ``num_concurrent``
                defaults to 2 when the caller does not supply it.

        Raises:
            AssertionError: If the ``AI_API_KEY`` environment variable is not
                set (only while assertions are enabled).
        """
        assert os.environ.get("AI_API_KEY") is not None, (
            "Please set the AI_API_KEY environment variable to use the JudgeFilter (can be empty string)"
        )
        eval_logger.info(
            "Pass num_concurrent=N to --metadata to set the number of concurrent requests for the JudgeFilter"
        )
        # Use a caller-supplied num_concurrent instead of also passing a
        # hard-coded one. Passing both raised "got multiple values for keyword
        # argument 'num_concurrent'".
        kwargs.setdefault("num_concurrent", 2)
        self.model = LocalChatCompletion(base_url=url, pretrained=model, **kwargs)
        self.prompt = self.PROMPT if prompt is None else prompt

    @staticmethod
    def create_message(str) -> dict:  # noqa: A002 - name kept for keyword-call compatibility
        """Wrap ``str`` as a single user-role chat message dict."""
        return {"role": "user", "content": str}

    def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
        """Ask the judge model whether each response matches its document's answer.

        Args:
            resps: Per-document responses, aligned with ``docs``. Each item is
                either a sequence of response strings or a single string (for
                example, when an earlier filter already selected one response).
            docs: Documents aligned with ``resps``. Each must provide the
                ``"question"`` and ``"answer"`` keys.

        Returns:
            One list per document, holding the judge's raw reply for each of
            that document's responses (a single reply when the item was a
            plain string).
        """
        requests = []
        group_sizes = []
        for doc_resps, doc in zip(resps, docs):
            # A bare string is one response. Wrapping it prevents iterating
            # over its characters and yields a one-verdict list, as before.
            responses = [doc_resps] if isinstance(doc_resps, str) else list(doc_resps)
            group_sizes.append(len(responses))
            # Judge each response on its own so the prompt contains the answer
            # text rather than the repr of a Python list.
            for resp in responses:
                requests.append(
                    [
                        self.create_message(
                            self.prompt
                            + "\n\n"
                            + f"Question: {doc['question']}\nAnswer: {resp}\nGround Truth: {doc['answer']}"
                        )
                    ]
                )

        if not requests:
            # Nothing to judge; skip the round-trip to the judge model.
            return [[] for _ in group_sizes]

        verdicts = list(self.model.simple_async_generate(requests, gen_kwargs={}))

        # Regroup the flat verdict list back into per-document lists.
        grouped = []
        start = 0
        for size in group_sizes:
            grouped.append(verdicts[start : start + size])
            start += size
        return grouped
