import re
from datasets import load_dataset, Dataset

# System message placed first in every chat prompt built by the loaders in
# this module. It asks the model to put its reasoning inside <reasoning> tags
# and its final answer inside <answer> tags. The leading and trailing empty
# entries give the string a leading and a trailing newline.
SYSTEM_PROMPT = "\n".join(
    [
        "",
        "Respond in the following format:",
        "<reasoning>",
        "...",
        "</reasoning>",
        "<answer>",
        "...",
        "</answer>",
        "",
    ]
)

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")

def extract_deepseek_r1_answer(text) -> str | None:
    words_to_check = ["applied_math", "Advanced-Math", "GSM8K_zh", 'EduChat-Math']
    pattern = r'\b(' + '|'.join(map(re.escape, words_to_check)) + r')\b'
    has_match = bool(re.search(pattern, text['repo_name'], flags=re.IGNORECASE))
    if has_match:
        pattern = r"\\boxed\{(.*)\}"
        match = re.search(pattern, text['output'])
        if match:
            return match.group(1)
        else:
            return None
    else:
        return None

def get_gsm8k_questions(dataset='openai/gsm8k', split="train", *, num_proc: int = 16) -> Dataset:
    """Load GSM8K and convert it into zero-shot chat prompts with numeric answers.

    Each row gets two fields:

    - ``prompt``: the ``SYSTEM_PROMPT`` system message followed by the
      question as the user message;
    - ``answer``: the final answer extracted by ``extract_hash_answer``.

    The ``question`` column is dropped. Rows whose solution has no ``####``
    marker (answer ``None``) are filtered out.

    Args:
        dataset: Dataset name passed to ``load_dataset`` with the ``'main'`` config.
        split: Split to return.
        num_proc: Number of worker processes used by ``map`` and ``filter``.

    Returns:
        The processed split.
    """
    data = load_dataset(dataset, 'main')[split]  # type: ignore
    # The original ``answer`` column is overwritten by the returned ``answer``
    # key, so only ``question`` needs to be removed.
    data = data.map(  # type: ignore
        lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['question']},
            ],
            'answer': extract_hash_answer(x['answer']),
        },
        num_proc=num_proc,
        remove_columns=["question"],
    )
    data = data.filter(lambda x: x['answer'] is not None, num_proc=num_proc)
    return data  # type: ignore

def get_deepseek_r1_questions(dataset='Congliu/Chinese-DeepSeek-R1-Distill-data-110k-SFT', split="train", *, num_proc: int = 16) -> Dataset:
    """Load the DeepSeek-R1 distillation SFT data as chat prompts with boxed answers.

    Each row gets two fields:

    - ``prompt``: the ``SYSTEM_PROMPT`` system message followed by
      ``instruction`` as the user message;
    - ``answer``: the value returned by ``extract_deepseek_r1_answer``.

    All listed source columns are removed. Rows without an answer are
    filtered out, and the number of remaining rows is printed.

    Args:
        dataset: Dataset name passed to ``load_dataset``.
        split: Split to return.
        num_proc: Number of worker processes used by both ``map`` and ``filter``.

    Returns:
        The processed split.
    """
    data = load_dataset(dataset)[split]  # type: ignore
    data = data.map(  # type: ignore
        lambda x: {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['instruction']},
            ],
            'answer': extract_deepseek_r1_answer(x),
        },
        num_proc=num_proc,
        remove_columns=["instruction", "output", "repo_name", "prompt_tokens_len", "input", "reasoning_content_tokens_len", "score", "content_tokens_len"],
    )

    # extract_deepseek_r1_answer yields None for rows whose repo_name is not in
    # its allow-list or whose output has no boxed answer.
    data = data.filter(lambda x: x['answer'] is not None, num_proc=num_proc)  # type: ignore
    # Report the dataset that was actually loaded, not a hard-coded name.
    print("GET {} data in {}".format(len(data), dataset))
    return data  # type: ignore

def get_hiyoga(dataset='hiyouga/math12k', split='train') -> Dataset:
    """Load a ``problem``/``answer`` math dataset (default ``hiyouga/math12k``) as chat prompts.

    Each row gets two fields:

    - ``prompt``: the ``SYSTEM_PROMPT`` system message followed by the
      problem text as the user message;
    - ``answer``: copied unchanged from the source row.

    The ``problem`` column is dropped, and rows whose answer is ``None`` are
    filtered out.

    NOTE(review): the function name misspells "hiyouga". It is kept as-is so
    existing callers keep working.
    """
    def to_chat_example(row):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['problem']},
        ]
        return {'prompt': messages, 'answer': row['answer']}

    raw = load_dataset(dataset)[split]  # type: ignore
    converted = raw.map(to_chat_example, num_proc=16, remove_columns=["problem"])
    return converted.filter(lambda row: row['answer'] is not None, num_proc=16)  # type: ignore

def get_unsloth_openmath(dataset="unsloth/OpenMathReasoning-mini", split='cot') -> Dataset:
    """Load the OpenMathReasoning-mini split as chat prompts with expected answers.

    Each row gets two fields:

    - ``prompt``: the ``SYSTEM_PROMPT`` system message followed by
      ``problem`` as the user message;
    - ``answer``: copied from ``expected_answer``.

    Every listed source column is dropped, and rows with a ``None`` answer
    are filtered out.
    """
    dropped_columns = [
        "expected_answer",
        "problem_type",
        "problem_source",
        "generation_model",
        "pass_rate_72b_tir",
        "generated_solution",
        "inference_mode",
        "problem",
    ]

    def to_chat_example(row):
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': row['problem']},
            ],
            'answer': row['expected_answer'],
        }

    raw = load_dataset(dataset)[split]
    converted = raw.map(to_chat_example, remove_columns=dropped_columns, num_proc=16)
    return converted.filter(lambda row: row['answer'] is not None, num_proc=16)  # type: ignore


def get_openr1_dapo_math(dataset="open-r1/DAPO-Math-17k-Processed", split="train") -> Dataset:
    """Load the DAPO-Math processed data (``"all"`` config) as chat prompts with solutions.

    Each row gets two fields:

    - ``prompt``: the ``SYSTEM_PROMPT`` system message followed by the
      original ``prompt`` text as the user message;
    - ``answer``: copied from ``solution``.

    The listed metadata columns are dropped, and rows with a ``None`` answer
    are filtered out.
    """
    def to_chat_example(row):
        user_text = row['prompt']
        return {
            # Overwrites the source ``prompt`` column in place, which is why
            # "prompt" is absent from the removal list below.
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': user_text},
            ],
            'answer': row['solution'],
        }

    dropped_columns = ["solution", "data_source", "source_prompt", "ability", "reward_model", "extra_info"]
    raw = load_dataset(dataset, "all")[split]
    converted = raw.map(to_chat_example, remove_columns=dropped_columns, num_proc=16)
    return converted.filter(lambda row: row['answer'] is not None, num_proc=16)  # type: ignore
