import os
import json
import argparse

from tqdm import tqdm
import numpy as np
import torch
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

import torch.nn as nn
from bert4torch.layers import CRF
from bert4torch.models import build_transformer_model, BaseModel


class Model(BaseModel):
    """BERT encoder with a linear tagging head, plus a CRF layer kept as an attribute.

    ``forward`` only produces per-token emission scores and a mask; the CRF
    stored in ``self.crf`` is not applied inside ``forward``.
    """

    def __init__(self, config_path):
        """Build the encoder, the linear head and the CRF.

        Args:
            config_path: Passed to ``build_transformer_model`` as ``config_path``.
                No checkpoint is loaded here (``checkpoint_path=None``).
        """
        super().__init__()
        # Fixed sizes: embedding_dims is 768 and len_categories is 7.
        hidden_size, num_tags = 768, 7
        self.bert = build_transformer_model(
            config_path=config_path,
            checkpoint_path=None,
            segment_vocab_size=0,
        )
        # Original note on this layer: "includes head and tail"
        # (presumably the first/last positions of the sequence -- TODO confirm).
        self.fc = nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags)

    def forward(self, token_ids):
        """Return ``(emission_score, attention_mask)`` for a batch of token ids.

        The mask is 1 where ``token_ids`` is greater than 0 and 0 elsewhere.
        """
        hidden_states = self.bert([token_ids])  # [btz, seq_len, hdsz]
        emissions = self.fc(hidden_states)  # [btz, seq_len, tag_size]
        mask = (token_ids > 0).long()
        return emissions, mask


def build_model(config_path, checkpoint_path):
    """Create a ``Model`` on the CPU and load weights from ``checkpoint_path``.

    Args:
        config_path: Configuration path forwarded to ``Model``.
        checkpoint_path: Weights file handed to ``load_weights`` with
            ``strict=False``.

    Returns:
        The constructed model with the weights loaded.
    """
    network = Model(config_path)
    network = network.to("cpu")
    network.load_weights(checkpoint_path, strict=False)
    return network


def pad_data(data, seq=256):
    """Right-pad the sequence axis of ``data`` with zeros up to length ``seq``.

    The sequence axis is axis 0 for a 1-D array and axis 1 for anything with
    two or more dimensions (e.g. ``(bs, seq)`` or ``(bs, seq, len(categories))``).
    All other axes are left untouched, so arrays of any rank >= 1 are accepted.

    Args:
        data: Array to pad.
        seq: Target length of the sequence axis.

    Returns:
        A new array of the same dtype whose sequence axis has length ``seq``.

    Raises:
        ValueError: If ``data`` has no dimensions, or if its sequence axis is
            already longer than ``seq`` (padding cannot shorten an array).
    """
    rank = len(data.shape)
    if rank == 0:
        raise ValueError("pad_data expects an array with at least one dimension")

    seq_axis = 0 if rank == 1 else 1
    current_len = data.shape[seq_axis]
    missing = seq - current_len
    if missing < 0:
        # np.pad would fail on a negative width with an unrelated-looking
        # message; report the actual problem instead.
        raise ValueError(
            f"sequence length {current_len} exceeds seq={seq}; cannot pad"
        )

    # Pad only at the end of the sequence axis; every other axis gets (0, 0).
    pad_width = [(0, 0)] * rank
    pad_width[seq_axis] = (0, missing)
    return np.pad(data, pad_width, "constant", constant_values=0)


def pad_data_npy(path, seq=256):
    """Load the ``.npy`` array stored at ``path`` and zero-pad it via ``pad_data``.

    Args:
        path: File handed to ``np.load``.
        seq: Target sequence length forwarded to ``pad_data``.

    Returns:
        Whatever ``pad_data`` returns for the loaded array.
    """
    loaded = np.load(path)
    return pad_data(loaded, seq)


def pad_data_bin(path, output, bs, seq=256, len_catagory=7):
    """Read a raw binary dump, reshape it into a batch and zero-pad it.

    Args:
        path: File read with ``np.fromfile``.
        output: Kind of data in the file. ``"emission_score"`` is read as
            float32 and reshaped to ``(bs, -1, len_catagory)``; any other value
            is read as int64 and reshaped to ``(bs, -1)``.
        bs: Batch size used as the first dimension of the reshape.
        seq: Target sequence length forwarded to ``pad_data``.
        len_catagory: Size of the last axis for emission scores.

    Returns:
        The reshaped array after padding by ``pad_data``.
    """
    if output == "emission_score":
        dtype, shape = np.float32, (bs, -1, len_catagory)
    else:
        dtype, shape = np.int64, (bs, -1)

    batch = np.fromfile(path, dtype=dtype).reshape(shape)
    return pad_data(batch, seq)


def evaluate(result_dir, label_dir, bs):
    """Score dumped model outputs against dumped labels, one batch file at a time.

    The number of batches is the number of entries in ``label_dir``. For batch
    index ``k`` the function reads ``{k}_0`` (emission scores) and ``{k}_1``
    (attention mask) from ``result_dir`` and ``{k}`` (labels) from
    ``label_dir``; for each of them the ``.npy`` file is used when it exists,
    otherwise the ``.bin`` file is read.

    Decoding uses the module-level ``crf`` and tag names come from the
    module-level ``categories_id2label``; both must be defined before calling.

    Args:
        result_dir: Directory holding the model output files.
        label_dir: Directory holding the label files.
        bs: Batch size used to reshape ``.bin`` files.

    Returns:
        A dict with the seqeval report under ``"seqeval_result"`` and
        f1/precision/recall dicts under ``"val-token  level"`` and
        ``"val-entity level"``.
    """
    # Counters start at a tiny epsilon so the divisions at the end never see zero.
    token_hits = token_pred_total = token_gold_total = 1e-10
    entity_hits = entity_pred_total = entity_gold_total = 1e-10
    gold_tag_seqs = []
    pred_tag_seqs = []

    batch_count = len(os.listdir(label_dir))
    for data_idx in tqdm(range(batch_count)):
        # Emission scores: prefer the .npy dump, fall back to the raw .bin dump.
        emission_npy = os.path.join(result_dir, f"{data_idx}_0.npy")
        emission_bin = os.path.join(result_dir, f"{data_idx}_0.bin")
        if os.path.exists(emission_npy):
            emission_array = pad_data_npy(emission_npy)
        else:
            print(emission_bin)
            emission_array = pad_data_bin(emission_bin, "emission_score", bs)

        # Attention mask, same lookup order.
        mask_npy = os.path.join(result_dir, f"{data_idx}_1.npy")
        mask_bin = os.path.join(result_dir, f"{data_idx}_1.bin")
        if os.path.exists(mask_npy):
            mask_array = pad_data_npy(mask_npy)
        else:
            mask_array = pad_data_bin(mask_bin, "attention_mask", bs)

        # Gold labels, same lookup order.
        label_npy = os.path.join(label_dir, f"{data_idx}.npy")
        label_bin = os.path.join(label_dir, f"{data_idx}.bin")
        if os.path.exists(label_npy):
            label_array = pad_data_npy(label_npy)
        else:
            label_array = pad_data_bin(label_bin, "labels", bs)

        gold = torch.Tensor(label_array)
        # Keep only rows whose first label is non-negative (original note:
        # "mask last data" -- presumably filler rows in the final batch).
        valid_rows = gold[:, 0] >= 0
        gold = gold[valid_rows]
        emissions = torch.Tensor(emission_array)[valid_rows]
        decode_mask = torch.Tensor(mask_array)[valid_rows]

        decoded = crf.decode(emissions, decode_mask)

        # One flat tag sequence per batch for seqeval; -100 entries are skipped.
        gold_tag_seqs.append(
            [categories_id2label[int(tag)] for row in gold for tag in row if tag != -100]
        )
        pred_tag_seqs.append(
            [categories_id2label[int(tag)] for row in decoded for tag in row if tag != -100]
        )

        # Token level: a hit is an equal tag at a position whose gold id is > 0.
        gold_positive = gold.gt(0)
        token_hits += (decoded.eq(gold) * gold_positive).sum().item()
        token_pred_total += decoded.gt(0).sum().item()
        token_gold_total += gold.gt(0).sum().item()

        # Entity level: compare (sample, start, end, type) tuples.
        pred_entities = trans_entity2tuple(decoded)
        gold_entities = trans_entity2tuple(gold)
        entity_hits += len(pred_entities & gold_entities)
        entity_pred_total += len(pred_entities)
        entity_gold_total += len(gold_entities)

    eval_result = classification_report(
        gold_tag_seqs,
        pred_tag_seqs,
        digits=4,
        mode='strict',
        scheme=IOB2,
    )
    print(eval_result)

    p1 = token_hits / token_pred_total
    r1 = token_hits / token_gold_total
    f1 = 2 * token_hits / (token_pred_total + token_gold_total)
    p2 = entity_hits / entity_pred_total
    r2 = entity_hits / entity_gold_total
    f2 = 2 * entity_hits / (entity_pred_total + entity_gold_total)
    print("val-token level: f1:{}, precision: {}, recall:{}".format(f1, p1, r1))
    print("val-entity level: f1:{}, precision: {}, recall:{}".format(f2, p2, r2))

    token_metrics = {"f1": f1, "precision": p1, "recall": r1}
    entity_metrics = {"f1": f2, "precision": p2, "recall": r2}
    return {
        "seqeval_result": eval_result,
        "val-token  level": token_metrics,
        "val-entity level": entity_metrics,
    }


def trans_entity2tuple(scores):
    """Convert a batch of tag ids into (sample id, start, end, entity type) tuples.

    Tag ids are mapped to names through the module-level
    ``categories_id2label``. A ``B-X`` tag opens an entity of type ``X``;
    directly following ``I-X`` tags of the same type extend its end position.
    Any other tag closes the open entity, and ``I-`` tags with no open entity
    are ignored.

    Args:
        scores: Iterable of samples, each an iterable of elements exposing
            ``.item()`` (e.g. rows of a 2-D tensor).

    Returns:
        A set of ``(sample_index, start, end, type)`` tuples used for
        computing metrics.
    """
    found = set()
    for sample_idx, tag_ids in enumerate(scores):
        open_entity = None  # entity that can still be extended, if any
        for position, tag_id in enumerate(tag_ids):
            tag_name = categories_id2label[tag_id.item()]
            if tag_name.startswith('B-'):
                # A new entity starts; whatever was open is finished.
                if open_entity is not None:
                    found.add(tuple(open_entity))
                open_entity = [sample_idx, position, position, tag_name[2:]]
            elif open_entity is None:
                continue
            elif tag_name.startswith('I-') and tag_name[2:] == open_entity[-1]:
                open_entity[2] = position
            else:
                # Non-continuing tag: finish the entity and stop extending it.
                found.add(tuple(open_entity))
                open_entity = None

        if open_entity is not None:
            found.add(tuple(open_entity))

    return found


def parse_arguments():
    """Parse the command-line options and make sure the output directory exists.

    ``out_path`` is converted to an absolute path and its parent directory is
    created when missing, so the caller can open the file for writing directly.

    Returns:
        The ``argparse.Namespace`` with ``result_dir``, ``out_path``,
        ``label_dir``, ``config_path``, ``ckpt_path`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='Bert_Base_Chinese postprocess for sequence labeling task.')
    parser.add_argument('-i', '--result_dir', type=str, required=True,
                        help='result dir for prediction results')
    parser.add_argument('-o', '--out_path', type=str, required=True,
                        help='save path for evaluation result')
    parser.add_argument('-l', '--label_dir', type=str, required=True,
                        help='label dir for label results')
    parser.add_argument('-c', '--config_path', type=str, required=True,
                        help='config path for export model')
    # The help text used to repeat the --result_dir description by mistake.
    parser.add_argument('-k', '--ckpt_path', type=str, default="./best_model.pt",
                        help='path to the trained model checkpoint')
    parser.add_argument('-bs', '--batch_size', type=int, default=64,
                        help='Batch size of output data.')
    arguments = parser.parse_args()
    arguments.out_path = os.path.abspath(arguments.out_path)

    # exist_ok=True replaces the check-then-create pair, which could fail if
    # the directory appeared between the check and the creation.
    os.makedirs(os.path.dirname(arguments.out_path), exist_ok=True)

    return arguments


if __name__ == '__main__':
    # Tag inventory; a tag's position in the list is its integer id.
    # `categories_id2label` and `crf` are read as module globals by
    # `evaluate` and `trans_entity2tuple`, so both names must stay as they are.
    categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
    categories_id2label = dict(enumerate(categories))

    args = parse_arguments()
    model = build_model(args.config_path, args.ckpt_path)
    crf = model.crf  # only the CRF layer of the loaded model is used for decoding

    evaluate_results = evaluate(args.result_dir, args.label_dir, args.batch_size)
    with open(args.out_path, 'w') as report_file:
        json.dump(evaluate_results, report_file, ensure_ascii=False, indent=4)
