import random
import threading
import time
import unittest
from types import SimpleNamespace

import requests

import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestEAGLEEngine(unittest.TestCase):
    """Offline-engine checks for EAGLE speculative decoding."""

    def test_eagle_accuracy(self):
        """Compare EAGLE engines against a plain (non-speculative) reference engine.

        For each EAGLE configuration in ``configs`` this checks that:

        1. greedy output for ``prompt1`` equals the reference engine's output;
        2. a long greedy generation, re-encoded with the target tokenizer,
           contains no ``eos_token_id`` (i.e. no unexpected EOS leaked into
           the text);
        3. a batch of prompts can be generated (outputs are only printed).

        Every engine is shut down in a ``finally`` block so a failing check
        does not leave an engine alive for the next config or the next test.
        """
        prompt1 = "Today is a sunny day and I like"
        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}

        # Get the reference output
        ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        try:
            ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
        finally:
            ref_engine.shutdown()

        # Inputs for cases 2 and 3 do not depend on the config: build them once.
        prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
        sampling_params2 = {
            "temperature": 0,
            "max_new_tokens": 1024,
            "skip_special_tokens": False,
        }
        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
        batch_prompts = [
            "Hello, my name is",
            "The president of the United States is",
            "The capital of France is",
            "The future of AI is",
        ]
        sampling_params3 = {"temperature": 0, "max_new_tokens": 30}

        # Test cases with different configurations
        base_config = {
            "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "speculative_algorithm": "EAGLE",
            "speculative_num_steps": 5,
            "speculative_eagle_topk": 8,
            "speculative_num_draft_tokens": 64,
            "mem_fraction_static": 0.7,
        }
        configs = [
            # Original config
            dict(base_config),
            # Config with CUDA graph disabled
            {**base_config, "disable_cuda_graph": True},
        ]

        for config in configs:
            # subTest makes a failure report which config it came from.
            with self.subTest(config=config):
                # Launch EAGLE engine
                engine = sgl.Engine(**config)
                try:
                    # Case 1: EAGLE output must match the normal engine's output
                    out1 = engine.generate(prompt1, sampling_params1)["text"]
                    print(f"{out1=}, {ref_output=}")
                    self.assertEqual(out1, ref_output)

                    # Case 2: EAGLE output must not contain an unexpected EOS
                    out2 = engine.generate(prompt2, sampling_params2)["text"]
                    print(f"{out2=}")
                    tokens = tokenizer.encode(out2, truncation=False)
                    self.assertNotIn(tokenizer.eos_token_id, tokens)

                    # Case 3: Batched prompts
                    outputs = engine.generate(batch_prompts, sampling_params3)
                    for prompt, output in zip(batch_prompts, outputs):
                        print("===============================")
                        print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
                finally:
                    # Always release the engine, even when a check above fails.
                    engine.shutdown()


# Llama-2-chat style prompts sent by the server tests. Each entry needs its own
# trailing comma: adjacent string literals would otherwise be silently
# concatenated into a single prompt.
prompts = [
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]",
    '[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
    "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]


class TestEAGLEServer(unittest.TestCase):
    """HTTP-server checks for EAGLE: concurrent/aborted requests and GSM8K accuracy."""

    @classmethod
    def setUpClass(cls):
        """Launch one EAGLE-enabled server shared by every test in this class."""
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--speculative-algorithm",
                "EAGLE",
                "--speculative-draft-model-path",
                DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
                "--speculative-num-steps",
                "5",
                "--speculative-eagle-topk",
                "8",
                "--speculative-num-draft-tokens",
                "64",
                "--mem-fraction-static",
                "0.7",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Stop the server launched in ``setUpClass`` via ``kill_process_tree``."""
        kill_process_tree(cls.process.pid)

    def send_request(self):
        """Send every prompt in ``prompts`` sequentially and require HTTP 200.

        Sleeps a random 0-2 s first so concurrent workers are staggered.

        Raises:
            AssertionError: if any response status is not 200.
        """
        time.sleep(random.uniform(0, 2))
        url = self.base_url + "/generate"
        for prompt in prompts:
            data = {
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            response = requests.post(url, json=data)
            self.assertEqual(response.status_code, 200, response.text)

    def send_requests_abort(self):
        """Send every prompt with a 1 s client timeout to simulate disconnecting clients.

        Request-level failures (normally the client-side timeout) are expected
        and only printed; any other exception propagates.
        """
        url = self.base_url + "/generate"
        for prompt in prompts:
            time.sleep(random.uniform(0, 2))
            data = {
                "model": "base",
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 1024,
                },
            }
            try:
                # set timeout = 1s to mock a client disconnect
                requests.post(url, json=data, timeout=1)
            except requests.exceptions.RequestException as e:
                print(e)

    def test_request_abort(self):
        """Run normal and client-aborted requests concurrently.

        Every normal request must succeed while other clients abort theirs.
        """
        concurrency = 4
        errors = []

        def run(target):
            # Exceptions raised in a worker thread are not re-raised by join(),
            # so collect them here and fail the test from the main thread.
            try:
                target()
            except Exception as e:
                errors.append(e)

        targets = [self.send_request] * concurrency + [
            self.send_requests_abort
        ] * concurrency
        threads = [threading.Thread(target=run, args=(target,)) for target in targets]
        for worker in threads:
            worker.start()
        for worker in threads:
            worker.join()

        self.assertEqual(errors, [])

    def test_gsm8k(self):
        """Run few-shot GSM8K against the server and require accuracy above 0.20."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.20)


if __name__ == "__main__":
    # Only when executed directly (not when imported by a test runner):
    # load and run every TestCase defined in this module.
    unittest.main(module="__main__")
