import argparse
from time import perf_counter

import numpy as np
import torch

from utils import load_hyperparam, convert_normal_parameter_to_int8, load_model
from model.tokenize import Tokenizer
from model.llama import *
from generate import LmGeneration_test


def build_parser():
    """Build the command-line parser for the LLaMA generation latency benchmark.

    Returns:
        argparse.ArgumentParser: parser whose defaults are shown in ``--help``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--load_model_path", default=None, type=str,
                        help="Path of the input model.")
    parser.add_argument("--test_path", type=str, required=True,
                        help="Path of the testset.")
    parser.add_argument("--prediction_path", type=str, required=True,
                        help="Path of the prediction file.")
    parser.add_argument("--config_path", type=str, required=True,
                        help="Path of the config file.")
    # NOTE(review): --batch_size is not read in this file; presumably it is
    # consumed by generate() or load_hyperparam() - confirm.
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size.")
    parser.add_argument("--world_size", type=int, default=1,
                        help="the number of gpus.")
    parser.add_argument("--seq_length", type=int, default=128,
                        help="Sequence length.")
    parser.add_argument("--use_int8", action="store_true")
    parser.add_argument("--top_k", type=int, default=10)
    parser.add_argument("--top_p", type=float, default=1)
    parser.add_argument("--temperature", type=float, default=0.85)
    parser.add_argument("--repetition_penalty_range", type=int, default=1024)
    parser.add_argument("--repetition_penalty_slope", type=float, default=0)
    parser.add_argument("--repetition_penalty", type=float, default=1.15)

    parser.add_argument("--spm_model_path", default=None, type=str,
                        help="Path of the sentence piece model.")

    parser.add_argument("--warmup_runs", type=int, default=2,
                        help="Untimed generate() calls made before measuring.")
    parser.add_argument("--benchmark_runs", type=int, default=10,
                        help="Timed generate() calls used for the latency statistics.")
    return parser


def load_prompts(path):
    """Return every line of the UTF-8 text file at *path*, trailing newlines included."""
    with open(path, "r", encoding="utf-8") as f:
        return list(f)


def build_model(args):
    """Instantiate ``LLaMa(args)`` with float16 parameters, load its weights and place it.

    With ``args.world_size > 1`` the model is sharded over ``cuda:0..world_size-1``
    via ``tensor_parallel`` (optionally converted to int8). Otherwise it is moved
    to CUDA when available, else kept on CPU.

    Returns:
        The model in eval mode, ready for generation.
    """
    # Parameters created inside LLaMa(args) get the float16 default tensor type;
    # the float32 default is restored even if construction fails.
    torch.set_default_tensor_type(torch.HalfTensor)
    try:
        model = LLaMa(args)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    model = load_model(model, args.load_model_path)
    model.eval()

    if args.world_size > 1:
        # Multi-GPU tensor parallelism; imported lazily so single-device runs
        # do not need the package installed.
        import tensor_parallel as tp
        gpus = [f"cuda:{i}" for i in range(args.world_size)]
        if args.use_int8:
            model = tp.tensor_parallel(model, gpus, delay_init=True)
            model = convert_normal_parameter_to_int8(model)
        else:
            model = tp.tensor_parallel(model, gpus)
    else:
        # NOTE: --use_int8 is only honoured on the world_size > 1 path above.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
    return model


def synchronize_cuda():
    """Wait for queued work on every visible CUDA device; no-op without CUDA."""
    if torch.cuda.is_available():
        for index in range(torch.cuda.device_count()):
            torch.cuda.synchronize(index)


def benchmark(lm_generation, args, prompt_tokens, warmup_runs=2, benchmark_runs=10):
    """Time repeated ``lm_generation.generate(args, prompt_tokens)`` calls.

    Args:
        lm_generation: object exposing ``generate(args, prompt_tokens)``.
        args: namespace forwarded unchanged to ``generate``.
        prompt_tokens: encoded prompts forwarded unchanged to ``generate``.
        warmup_runs: untimed calls made first (e.g. to trigger lazy init).
        benchmark_runs: timed calls.

    Returns:
        tuple: (result of the last timed call, list of per-call latencies in seconds).
    """
    for _ in range(warmup_runs):
        lm_generation.generate(args, prompt_tokens)

    latencies = []
    tokens = None
    for _ in range(benchmark_runs):
        # CUDA kernels run asynchronously; synchronize on both sides so the
        # wall-clock interval covers the GPU work of exactly this call.
        synchronize_cuda()
        start_time = perf_counter()
        tokens = lm_generation.generate(args, prompt_tokens)
        synchronize_cuda()
        latencies.append(perf_counter() - start_time)
    return tokens, latencies


def truncate_generated(token_ids, seq_length, pad_id, eos_id):
    """Trim one generated id sequence for decoding.

    The sequence is cut to *seq_length* ids, then before the first *pad_id*, then
    before the first *eos_id*. Each stop id is checked independently, so a
    sequence without padding is still cut at eos.

    Returns:
        list[int]: the trimmed ids.
    """
    trimmed = list(token_ids[:seq_length])
    for stop_id in (pad_id, eos_id):
        if stop_id in trimmed:
            trimmed = trimmed[: trimmed.index(stop_id)]
    return trimmed


def latency_stats_ms(latencies):
    """Return ``(mean, std, p95)`` in milliseconds for latencies given in seconds.

    ``std`` is the population standard deviation (``np.std`` default); ``p95``
    is ``np.percentile(..., 95)``.
    """
    latencies_ms = np.asarray(latencies, dtype=float) * 1000
    return (
        float(np.mean(latencies_ms)),
        float(np.std(latencies_ms)),
        float(np.percentile(latencies_ms, 95)),
    )


def main():
    """Load the model, benchmark generation on the test prompts and report latencies.

    Decoded generations are written to ``--prediction_path`` (one per line).
    Latency statistics and the first generation are printed.
    """
    parser = build_parser()
    args = parser.parse_args()
    if args.warmup_runs < 0:
        parser.error("--warmup_runs must be >= 0")
    if args.benchmark_runs < 1:
        parser.error("--benchmark_runs must be >= 1")

    args = load_hyperparam(args)

    args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()

    model = build_model(args)
    # LmGeneration_test is the generator imported by this script; its
    # generate() takes encoded prompts and returns ids supporting .tolist().
    lm_generation = LmGeneration_test(model, args.tokenizer)

    prompts = load_prompts(args.test_path)
    if not prompts:
        parser.error(f"no prompts found in {args.test_path}")
    prompt_tokens = [args.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

    with torch.no_grad():
        tokens, latencies = benchmark(
            lm_generation, args, prompt_tokens, args.warmup_runs, args.benchmark_runs
        )

    decoder = [
        args.tokenizer.decode(
            truncate_generated(t, args.seq_length, args.tokenizer.pad_id, args.tokenizer.eos_id)
        )
        for t in tokens.tolist()
    ]

    with open(args.prediction_path, "w", encoding="utf-8") as f:
        for text in decoder:
            f.write(text + "\n")

    time_avg_ms, time_std_ms, time_p95_ms = latency_stats_ms(latencies)
    print(f"P95延时 (ms) - {time_p95_ms}; 平均延时 (ms) - {time_avg_ms:.2f} +/- {time_std_ms:.2f};")

    # Newly generated characters are approximated as decoded length minus raw
    # prompt length. NOTE(review): the prompt keeps its trailing newline and
    # decode() may normalise whitespace, so this count is approximate.
    generated_chars = len(decoder[0]) - len(prompts[0])
    if generated_chars > 0:
        print(f"字平均延时 (ms) - {time_avg_ms / generated_chars}; ")
    else:
        print("字平均延时 (ms) - N/A (no new characters generated); ")
    print(decoder[0])


if __name__ == '__main__':
    main()
