#! -*- coding:utf-8 -*-
# bert+crf用来做实体识别
# 数据集：http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# [valid_f1]  token_level: 97.06； entity_level: 95.90


import numpy as np
import torch
from torch.utils.data import DataLoader
from apex.optimizers import FusedLAMB
import apex_C
from apex import amp
import migraphx
import torch.nn as nn
import torch.optim as optim
# from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.snippets import sequence_padding, ListDataset
from bert4torch.layers import CRF
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
from bert4torch.models import BaseModelDDP
import os
import time
import multiprocessing as mp
from multiprocessing import Process, Queue, Manager

maxlen = 256
batch_size = 64
categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
# Tag id <-> tag name lookups; id 0 is 'O'.
categories_id2label = dict(enumerate(categories))
categories_label2id = {label: idx for idx, label in enumerate(categories)}

# BERT base
config_path = '/datasets/bert-base-chinese/config.json'
dict_path = '/datasets/bert-base-chinese/vocab.txt'
device = "cuda"
# Dumped .bin files are grouped per visible device. Fall back to '0' when
# HIP_VISIBLE_DEVICES is unset; otherwise os.path.join would receive None
# and raise TypeError at import time.
gpuid = os.getenv('HIP_VISIBLE_DEVICES', '0')
labdir = os.path.join('results', gpuid, 'label')
resultdir = os.path.join('results', gpuid, 'data')
os.makedirs(resultdir, exist_ok=True)
os.makedirs(labdir, exist_ok=True)

def AllocateOutputMemory(model):
    """Pre-allocate one GPU buffer per output of a compiled MIGraphX program.

    Args:
        model: compiled MIGraphX program exposing ``get_outputs()``.

    Returns:
        dict mapping each output name to the argument returned by
        ``migraphx.allocate_gpu`` for that output's shape. Callers add the
        input argument to this same dict before passing it to ``model.run``.
    """
    output_shapes = model.get_outputs()
    return {name: migraphx.allocate_gpu(s=output_shapes[name]) for name in output_shapes.keys()}

# Dataset loading
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filename):
        """Parse a one-character-per-line BIO file into NER samples.

        Sentences are separated by blank lines and each non-empty line is
        ``"<char> <tag>"``. Every sample is ``[text, [start, end, type], ...]``
        where ``start``/``end`` are inclusive line (character) indices into
        ``text`` and ``type`` is the tag suffix after ``B-``.

        Args:
            filename: path of a UTF-8 encoded corpus file.

        Returns:
            list of samples as described above.
        """
        samples = []
        with open(filename, encoding='utf-8') as fh:
            content = fh.read()
        for chunk in content.split('\n\n'):
            # Drop empty lines left by runs of blank lines or a trailing newline;
            # they would otherwise break the "<char> <tag>" unpacking below.
            lines = [line for line in chunk.split('\n') if line]
            if not lines:
                continue
            sample = ['']
            for i, line in enumerate(lines):
                # rsplit so a space used as the character itself still parses.
                char, flag = line.rsplit(' ', 1)
                sample[0] += char
                if flag[0] == 'B':
                    sample.append([i, i, flag[2:]])
                elif flag[0] == 'I' and len(sample) > 1:
                    # An I- tag with no preceding B- has no entity to extend;
                    # without this guard it would index into the text string.
                    sample[-1][1] = i
            samples.append(sample)
        return samples


# Build the tokenizer (lower-casing enabled via do_lower_case=True).
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# Prefer a previously compiled fp16 MIGraphX program; otherwise parse the ONNX
# export with a fixed [64, 256] input, quantize it to fp16 and compile for GPU 0.
_mxr_path = "/home/sunzhq/workspace/yidong/bert/bert4torch_cmcc/examples/sequence_labeling/bert_best_fp16.mxr"
_onnx_path = "/home/sunzhq/workspace/yidong/bert/bert4torch_cmcc/examples/sequence_labeling/bert_best.onnx"

if not os.path.isfile(_mxr_path):
    # Load the model
    maxInput = {"input": [64, 256]}
    model = migraphx.parse_onnx(_onnx_path, map_input_dims=maxInput)
    migraphx.quantize_fp16(model)

    # Compile. offload_copy=False: host<->device transfers are done explicitly
    # by the caller (migraphx.to_gpu / from_gpu are used during evaluation).
    model.compile(migraphx.get_target("gpu"), offload_copy=False, device_id=0)
else:
    model = migraphx.load(_mxr_path)
    print("###############migraphx-driver#####################")

# Name of the program's first input, and the output buffers it writes into.
inputName = [*model.get_inputs().keys()][0]
modelData = AllocateOutputMemory(model)



def collate_fn(batch):
    """Turn a list of NER samples into padded token-id and label tensors.

    Each sample is ``[text, [start, end, type], ...]`` with inclusive character
    offsets, as produced by ``MyDataset.load_data``.

    Args:
        batch: list of samples.

    Returns:
        ``(token_ids, labels)``: ``torch.long`` tensors on ``device``, padded to
        ``maxlen`` columns. Token ids are padded with the tokenizer's pad id and
        labels with -100 so padding positions can be told apart when scoring.
        ``maxlen`` must match the sequence length the MIGraphX program was
        compiled for (256).
    """
    batch_token_ids, batch_labels = [], []
    for text, *entities in batch:
        tokens = tokenizer.tokenize(text, maxlen=maxlen)  # truncated to maxlen
        mapping = tokenizer.rematch(text, tokens)
        # Character offset -> token index for the first/last character of each token.
        start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
        end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
        token_ids = tokenizer.tokens_to_ids(tokens)

        # Start every position as 0, which is the id of 'O'.
        labels = np.zeros(len(token_ids), dtype=np.int64)
        for start, end, label in entities:
            # Entities whose boundaries do not fall on token boundaries
            # (including ones truncated away) are skipped.
            if start in start_mapping and end in end_mapping:
                start_idx = start_mapping[start]
                end_idx = end_mapping[end]
                labels[start_idx] = categories_label2id['B-' + label]
                labels[start_idx + 1:end_idx + 1] = categories_label2id['I-' + label]

        batch_token_ids.append(token_ids)
        batch_labels.append(labels)

    batch_token_ids = torch.tensor(
        sequence_padding(batch_token_ids, length=maxlen, value=tokenizer._token_pad_id),
        dtype=torch.long,
        device=device,
    )
    batch_labels = torch.tensor(
        sequence_padding(batch_labels, length=maxlen, value=-100),
        dtype=torch.long,
        device=device,
    )
    return batch_token_ids, batch_labels

# Validation data: example.dev, batched by `batch_size` and collated by `collate_fn`.
valid_dataloader = DataLoader(
    MyDataset('/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.dev'),
    collate_fn=collate_fn,
    batch_size=batch_size,
)


def pad_data(data, seq=256):
    """Zero-pad the sequence axis of `data` at the end up to length `seq`.

    The sequence axis is axis 0 for 1-D input and axis 1 otherwise; the
    non-1-D/2-D case is laid out for ``(bs, seq, len(categories))`` arrays.
    """
    rank = len(data.shape)
    if rank == 1:
        widths = (0, seq - data.shape[0])
    elif rank == 2:
        widths = ((0, 0), (0, seq - data.shape[1]))
    else:
        widths = ((0, 0), (0, seq - data.shape[1]), (0, 0))
    return np.pad(data, widths, "constant", constant_values=0)


def pad_data_bin(data, output, bs, seq=256, len_catagory=7):
    """Reshape `data` to its batch layout, then pad the sequence axis to `seq`.

    ``"emission_score"`` outputs become ``(bs, -1, len_catagory)``; any other
    output name becomes ``(bs, -1)``. Padding is delegated to ``pad_data``.
    """
    target_shape = (bs, -1, len_catagory) if output == "emission_score" else (bs, -1)
    return pad_data(data.reshape(target_shape), seq)

# NOTE: `crf`, used by evaluate(), is not created at import time; it is bound
# to the loaded PyTorch model's CRF layer in the `__main__` section.


def _to_model_input(token_ids, rows):
    """Copy a token-id batch into a fresh C-contiguous int64 array with `rows` rows.

    Rows beyond the real batch stay zero, so a short final batch still matches
    the fixed batch dimension the MIGraphX program is run with.
    """
    ids = token_ids.detach().cpu().numpy()
    # Copying into a freshly allocated buffer is required: without it MIGraphX
    # sees wrong input strides (original author's note).
    buf = np.zeros((rows,) + ids.shape[1:], dtype=np.int64)
    buf[:ids.shape[0]] = ids
    return buf


def evaluate(data):
    """Run the MIGraphX model over `data`, dump raw outputs and score with the CRF.

    For each batch, the labels are written to ``{labdir}/{idx}.bin`` and the two
    model outputs to ``{resultdir}/{idx}_0.bin`` / ``{idx}_1.bin``. The emissions
    are decoded with the module-level ``crf``, and token- and entity-level counts
    are accumulated. Timing/throughput statistics are printed at the end.

    The program runs with a fixed batch dimension, so a short final batch is
    zero-padded to ``batch_size`` rows. Padded rows get label -1 in the dumped
    label file and are excluded from the metrics. Previously that batch was
    silently skipped.

    Args:
        data: iterable of ``(token_ids, labels)`` tensor pairs from ``collate_fn``.

    Returns:
        ``(f1, precision, recall, f1_entity, precision_entity, recall_entity)``.
    """
    X, Y, Z = 1e-10, 1e-10, 1e-10      # token level: correct / predicted / gold non-'O'
    X2, Y2, Z2 = 1e-10, 1e-10, 1e-10   # entity level: correct / predicted / gold entities
    end = 0                            # accumulated pure inference time (s)
    infer_times = []
    total_infer_times = []
    data_idx = 0
    num_samples = 0

    # Warm-up: one untimed run on the first batch.
    first = next(iter(data), None)
    if first is not None:
        modelData[inputName] = migraphx.to_gpu(migraphx.argument(_to_model_input(first[0], batch_size)))
        model.run(modelData)

    # Started after warm-up so warm-up time is not counted in the first batch.
    total_start = time.time()
    for token_ids, label in tqdm(data):
        real_rows = token_ids.shape[0]
        img_data = _to_model_input(token_ids, batch_size)
        modelData[inputName] = migraphx.to_gpu(migraphx.argument(img_data))

        start = time.time()
        preds_dcu = model.run(modelData)
        elapsed = time.time() - start
        end += elapsed
        infer_times.append(elapsed)
        total_infer_times.append(time.time() - total_start)

        emission_score = np.array(migraphx.from_gpu(preds_dcu[0]))
        attention_mask = np.array(migraphx.from_gpu(preds_dcu[1]))
        labels = label.cpu().numpy()

        # Save .bin files; label rows beyond the real batch are filled with -1.
        labels = np.pad(labels, ((0, batch_size - labels.shape[0]), (0, 0)), 'constant', constant_values=-1)
        labels.tofile(f'{labdir}/{data_idx}.bin')

        emission_score = np.pad(emission_score, ((0, batch_size - emission_score.shape[0]), (0, 0), (0, 0)), 'constant')
        attention_mask = np.pad(attention_mask, ((0, batch_size - attention_mask.shape[0]), (0, 0)), 'constant')
        emission_score.tofile(f'{resultdir}/{data_idx}_0.bin')
        attention_mask.tofile(f'{resultdir}/{data_idx}_1.bin')

        labels = pad_data_bin(labels, "labels", batch_size)
        emission_score = pad_data_bin(emission_score, "emission_score", batch_size)
        attention_mask = pad_data_bin(attention_mask, "attention_mask", batch_size)

        # Keep only the real (non-padded) samples for scoring.
        labels = torch.Tensor(labels)[:real_rows]
        emission_score = torch.Tensor(emission_score)[:real_rows]
        attention_mask = torch.Tensor(attention_mask)[:real_rows]

        scores = crf.decode(emission_score, attention_mask)

        # Token level: only positions whose gold tag id is > 0 (not 'O') count.
        gold_mask = labels.gt(0)
        X += (scores.eq(labels) * gold_mask).sum().item()
        Y += scores.gt(0).sum().item()
        Z += labels.gt(0).sum().item()

        # Entity level
        entity_pred = trans_entity2tuple(scores)
        entity_true = trans_entity2tuple(labels)
        X2 += len(entity_pred.intersection(entity_true))
        Y2 += len(entity_pred)
        Z2 += len(entity_true)

        data_idx += 1
        num_samples += real_rows
        total_start = time.time()

    print("total_sample_data:", num_samples)
    print(f"total_infer_time: {end}s")
    if infer_times:
        # Throughput counts batch_size rows per run: what the program actually processes.
        avg_infer_fps = batch_size * len(infer_times) / sum(infer_times)
        print(f'avg_infer_fps: {avg_infer_fps}samples/s')
        load_data_infer_time = sum(total_infer_times)
        load_data_avg_infer_fps = len(total_infer_times) * batch_size / load_data_infer_time
        print(f'load_data_total_infer_time: {load_data_infer_time}s')
        print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_fps} samples/s')

    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
    f2, precision2, recall2 = 2 * X2 / (Y2 + Z2), X2 / Y2, X2 / Z2
    return f1, precision, recall, f2, precision2, recall2

def trans_entity2tuple(scores):
    """Convert a batch of tag-id rows into a set of entity tuples for scoring.

    Each entity is ``(sample_idx, start, end, entity_type)`` with inclusive
    positions. A ``B-X`` tag opens an entity, and following ``I-X`` tags of the
    same type extend it. Any other tag closes it, and a stray ``I-`` with no
    open entity is ignored. Positions holding -100 (label padding from
    collate_fn) are skipped without closing the open entity.

    Args:
        scores: per-sample rows of tag ids; each element must support
            ``.item()`` (assumed to be a 2-D torch tensor — int predictions or
            float labels).

    Returns:
        set of entity tuples.
    """
    batch_entity_ids = set()
    for i, one_samp in enumerate(scores):
        entities = []
        current = None  # entity currently being extended, or None if none is open
        for j, item in enumerate(one_samp):
            value = item.item()
            if value == -100:  # padding / ignore_index
                continue
            # int(): labels arrive as float tensors, but the lookup keys are ints.
            flag_tag = categories_id2label[int(value)]

            if flag_tag.startswith('B-'):
                current = [i, j, j, flag_tag[2:]]
                entities.append(current)
            elif current is not None and flag_tag.startswith('I-') and flag_tag[2:] == current[3]:
                current[2] = j
            else:
                current = None

        batch_entity_ids.update(tuple(ent) for ent in entities)
    return batch_entity_ids


class Model(BaseModel):
    """BERT encoder followed by a linear emission head and a CRF layer.

    Args:
        config_path: BERT config passed to ``build_transformer_model``.
        hidden_size: width of the encoder output fed to the linear head
            (768 for BERT base).
        num_tags: number of tags; defaults to ``len(categories)`` (7 here) so
            the head stays consistent with the label set.
    """

    def __init__(self, config_path, hidden_size=768, num_tags=None):
        super().__init__()
        if num_tags is None:
            num_tags = len(categories)
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=None, segment_vocab_size=0)
        # Emission scores for every position (original note: including the
        # first and last tokens).
        self.fc = nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags)

    def forward(self, token_ids):
        """Return ``(emission_score, attention_mask)`` for a batch of token ids."""
        sequence_output = self.bert([token_ids])  # [btz, seq_len, hdsz]
        emission_score = self.fc(sequence_output)  # [btz, seq_len, tag_size]
        # Non-zero ids are treated as real tokens (assumes the pad id is 0).
        attention_mask = token_ids.gt(0).long()
        return emission_score, attention_mask


def build_model(config_path, checkpoint_path):
    """Instantiate `Model` on the CPU and load its weights from `checkpoint_path`.

    ``strict=False`` is forwarded to ``load_weights`` (presumably to tolerate
    key mismatches with the checkpoint — confirm against bert4torch).
    """
    net = Model(config_path)
    net = net.to("cpu")
    net.load_weights(checkpoint_path, strict=False)
    return net

if __name__ == '__main__':
    ptmodel = build_model(config_path, "./best_model.pt")
    # Module-level name read by evaluate() to decode the MIGraphX emissions.
    crf = ptmodel.crf

    # Timing log is appended to log/time.txt; create the directory if missing.
    # `with` guarantees the file is closed even if evaluation raises.
    os.makedirs('log', exist_ok=True)
    with open(os.path.join('log/', 'time.txt'), 'a', encoding='utf-8') as time_fw:
        # Record the start time
        time_fw.write('Start Time: {:.6f}\n'.format(time.time()))

        f1, precision, recall, f2, precision2, recall2 = evaluate(valid_dataloader)

        print(f'[val-token  level] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f}')
        print(f'[val-entity level] f1: {f2:.5f}, p: {precision2:.5f} r: {recall2:.5f}\n')

        # Record the end time
        time_fw.write('End Time: {:.6f}\n'.format(time.time()))
