#! -*- coding:utf-8 -*-
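# BERT+CRF named-entity recognition (NER). This script loads trained weights and exports the model to a static-shape ONNX file.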
# 数据集：http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# [valid_f1]  token_level: 97.06； entity_level: 95.90


import numpy as np
import torch
from torch.utils.data import DataLoader
from apex.optimizers import FusedLAMB
import apex_C
from apex import amp

import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.layers import CRF
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
from bert4torch.models import BaseModelDDP
import os
import time

# Sequence length / batch size (also used as the ONNX dummy-input shape).
maxlen = 256
batch_size = 64

# BIO tag set plus the id <-> label lookup tables derived from it.
categories = ['O', 'B-LOC', 'I-LOC', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG']
categories_id2label = dict(enumerate(categories))
categories_label2id = {label: idx for idx, label in enumerate(categories)}

# BERT base
config_path = '/datasets/bert-base-chinese/config.json'
# NOTE(review): this path is also passed to build_transformer_model as the
# pretrained checkpoint, but it looks like a fine-tuned model file — confirm.
checkpoint_path = "/models/best_model.pt"
dict_path = '/datasets/bert-base-chinese/vocab.txt'
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Distributed (DDP) initialisation, disabled in this single-process script:
#local_rank = int(os.environ['LOCAL_RANK'])
#print("local_rank ", local_rank)
#torch.cuda.set_device(local_rank)
#device = torch.device("cuda", local_rank)
#torch.distributed.init_process_group(backend='nccl')

# Fix the random seed for reproducibility.
seed_everything(42)

# 加载数据集
class MyDataset(ListDataset):
    """Character-level NER dataset read from a "char tag" text file.

    Each sample produced by ``load_data`` is a list whose first element is the
    sentence text and whose remaining elements are ``[start, end, type]``
    entity spans (inclusive character offsets; ``type`` is the tag with its
    ``B-``/``I-`` prefix removed, e.g. ``'LOC'``).
    """

    @staticmethod
    def load_data(filename):
        """Parse ``filename`` into ``[text, [start, end, type], ...]`` samples.

        The file holds one ``"<char> <tag>"`` pair per line, with sentences
        separated by a blank line (``'\\n\\n'``). Tags starting with ``B``
        open a new span; tags starting with ``I`` extend the most recently
        opened span; any other tag (e.g. ``O``) only contributes its char.

        Blank / whitespace-only lines inside a sentence chunk are skipped.
        Such lines come from the trailing newline at end of file, or from
        three or more consecutive newlines between sentences. The previous
        implementation crashed on them when unpacking ``char, flag``.

        Args:
            filename: Path of a UTF-8 text file in the format above.

        Returns:
            A list of samples, one per non-empty sentence chunk.

        Raises:
            ValueError: if a non-blank line does not split on ``' '`` into
                exactly two fields.
        """
        samples = []
        with open(filename, encoding='utf-8') as fh:
            content = fh.read()
        for chunk in content.split('\n\n'):
            # Keep only real "char tag" lines; blank lines carry no character.
            lines = [line for line in chunk.split('\n') if line.strip()]
            if not lines:
                continue
            sample = ['']
            for i, line in enumerate(lines):
                char, flag = line.split(' ')
                sample[0] += char
                if flag[0] == 'B':
                    # Open a new span [start, end, type]; e.g. 'B-LOC'[2:] -> 'LOC'.
                    sample.append([i, i, flag[2:]])
                elif flag[0] == 'I':
                    # Extend the latest span to this character. NOTE: an 'I'
                    # tag before any 'B' tag in a sentence hits the text
                    # string at sample[-1] and raises TypeError.
                    sample[-1][1] = i
            samples.append(sample)
        return samples


# Build the tokenizer from the BERT vocab file (do_lower_case=True lower-cases input text).
tokenizer = Tokenizer(dict_path, do_lower_case=True)

def collate_fn(batch):
    """Convert ``[text, [start, end, type], ...]`` samples into padded tensors.

    Each text is tokenized (truncated to ``maxlen``) and every entity span is
    projected from character offsets onto token positions: the first token
    gets the ``B-<type>`` id, following tokens up to the end token get the
    ``I-<type>`` id, and everything else stays 0 (the id of ``'O'``). Spans
    whose start or end character does not map to a token are skipped
    (presumably because they were truncated away — confirm).

    Returns:
        ``(token_ids, labels)``, both ``torch.long`` tensors on ``device``,
        padded with ``sequence_padding``.
    """
    all_token_ids, all_labels = [], []
    for sample in batch:
        text, spans = sample[0], sample[1:]
        tokens = tokenizer.tokenize(text, maxlen=maxlen)
        char_spans = tokenizer.rematch(text, tokens)

        # Character offset -> token index, keyed by each token's first / last char.
        char_start_to_token = {}
        char_end_to_token = {}
        for token_idx, chars in enumerate(char_spans):
            if not chars:
                continue
            char_start_to_token[chars[0]] = token_idx
            char_end_to_token[chars[-1]] = token_idx

        token_ids = tokenizer.tokens_to_ids(tokens)
        labels = np.zeros(len(token_ids))
        for char_start, char_end, entity_type in spans:
            if char_start not in char_start_to_token or char_end not in char_end_to_token:
                continue
            first = char_start_to_token[char_start]
            last = char_end_to_token[char_end]
            labels[first] = categories_label2id['B-' + entity_type]
            labels[first + 1:last + 1] = categories_label2id['I-' + entity_type]

        all_token_ids.append(token_ids)
        all_labels.append(labels)

    token_tensor = torch.tensor(sequence_padding(all_token_ids), dtype=torch.long, device=device)
    label_tensor = torch.tensor(sequence_padding(all_labels), dtype=torch.long, device=device)
    return token_tensor, label_tensor

# Build the datasets / data loaders.
# The training split and its DistributedSampler are disabled in this script:
#train_dataset = MyDataset('/workspace/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.train')
#train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
#train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn)
# The dev file is read at import time, when this module-level loader is built.
valid_dataloader = DataLoader(
    dataset=MyDataset('/datasets/bert-base-chinese/china-people-daily-ner-corpus/example.dev'),
    batch_size=batch_size,
    collate_fn=collate_fn,
)

# 定义bert上的模型结构
class Model(BaseModel):
    """BERT encoder + linear emission layer + CRF for token-level NER."""

    def __init__(self):
        super().__init__()
        # Submodule names (bert / fc / crf) and their registration order are
        # what saved weights are keyed on — do not rename or reorder.
        self.bert = build_transformer_model(
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            segment_vocab_size=0,
        )
        num_tags = len(categories)
        # Emission scores for every position (the original comment notes that
        # the first and last positions are included).
        self.fc = nn.Linear(768, num_tags)
        self.crf = CRF(num_tags)

    def forward(self, token_ids):
        """Return ``(emission_score, attention_mask)`` for ``token_ids``.

        ``attention_mask`` is 1 where the token id is > 0 and 0 elsewhere
        (assumes id 0 is padding — confirm against sequence_padding).
        """
        hidden_states = self.bert([token_ids])  # [btz, seq_len, hdsz]
        emission_score = self.fc(hidden_states)  # [btz, seq_len, tag_size]
        attention_mask = token_ids.gt(0).long()
        return emission_score, attention_mask

    def predict(self, token_ids):
        """Switch to eval mode and return the CRF-decoded tag path (no gradients)."""
        self.eval()
        with torch.no_grad():
            scores, mask = self.forward(token_ids)
            return self.crf.decode(scores, mask)  # [btz, seq_len]

# Instantiate the model and move it to the selected device.
model = Model()
model = model.to(device)

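## Wrap the model in DDP for multi-GPU use (currently disabled: the BaseModelDDP line further below is commented out); master_rank is the local_rank that prints training progress.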

class Loss(nn.Module):
    """Training loss computed by the model's CRF layer.

    ``outputs`` is the ``(emission_score, attention_mask)`` tuple returned by
    ``Model.forward``; it is unpacked into the CRF call together with
    ``labels``.
    """

    def forward(self, outputs, labels):
        # ``model`` only has a ``.module`` attribute when it is wrapped by
        # BaseModelDDP (currently commented out). Unwrap only when present so
        # the loss works for both the plain and the DDP-wrapped model.
        net = getattr(model, 'module', model)
        return net.crf(*outputs, labels)

# Optimizer. The apex amp mixed-precision and DDP wrapping steps below are
# currently disabled (commented out).
optimizer = optim.Adam(params=model.parameters(), lr=6e-5)
#model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", master_weights=True, verbosity=0)
#model = BaseModelDDP(model, master_rank=0, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=False)

# use_apex=True would enable apex amp mixed precision in compile().
model.compile(loss=Loss(), optimizer=optimizer)
#------------------------------------------------------------

def evaluate(data, onnx_path="bert_best_static.onnx", export_batch_size=None,
             export_seq_len=None, opset_version=13, dynamic_axes=None):
    """Export the module-level ``model`` to ONNX using a random dummy input.

    Despite its name, this function computes no metric. It puts ``model`` in
    eval mode, traces ``model(dummy_input)`` and writes the graph to
    ``onnx_path``. The graph has one input (``"input"``) and two outputs
    (``"emission_scores"``, ``"attention_mask"``).

    Args:
        data: Unused; kept so existing calls such as
            ``evaluate(valid_dataloader)`` keep working.
        onnx_path: Destination file for the exported model.
        export_batch_size: Batch dimension of the dummy input. Defaults to the
            module-level ``batch_size``.
        export_seq_len: Sequence dimension of the dummy input. Defaults to the
            module-level ``maxlen``.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
        dynamic_axes: Optional mapping forwarded to ``torch.onnx.export``, e.g.
            ``{'input': {1: 'token'}}`` to make the sequence axis dynamic.
            ``None`` (default) exports fully static shapes.
    """
    if export_batch_size is None:
        export_batch_size = batch_size
    if export_seq_len is None:
        export_seq_len = maxlen

    model.eval()
    # Ids are drawn from [1, 2000), so no position is 0. The attention mask
    # (token_ids > 0) is therefore all ones. The upper bound 2000 is
    # presumably well below the vocab size — confirm.
    dummy_input = torch.randint(
        1, 2000, size=(export_batch_size, export_seq_len), dtype=torch.long, device=device
    )

    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        opset_version=opset_version,
        input_names=["input"],
        output_names=["emission_scores", "attention_mask"],
        dynamic_axes=dynamic_axes,
        do_constant_folding=True,  # fold constant subgraphs at export time
    )
    if dynamic_axes is None:
        print("✅ 静态 ONNX 导出完成！")
    else:
        print("✅ ONNX 导出完成！")


if __name__ == '__main__':
    # Append the script start time (seconds since the epoch) to log/time.txt.
    # Create log/ if it is missing (open() would otherwise raise
    # FileNotFoundError), and close the file right after writing.
    log_dir = 'log/'
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'time.txt'), 'a', encoding='utf-8') as time_fw:
        time_fw.write('Start Time: {:.6f}\n'.format(time.time()))

    # Load the trained weights, then export the model to ONNX.
    model.load_weights("/models/best_model.pt")

    evaluate(valid_dataloader)
    
