"""Encode prompts for and decode outputs of GLM-130B with the icetk tokenizer.

Provides ``tokenize`` (prompt strings -> padded int id tensor, unpadded
lengths and mask positions) and ``get_res`` (batches of generated ids ->
decoded text). When run as a script it reads whitespace-separated token ids
from the file ``out`` and prints them along with the decoded text of the
first line.
"""
from __future__ import print_function

import sys
import ctypes

# Make shared libraries loaded from here on use RTLD_GLOBAL, so their symbols
# are visible to libraries loaded after them. This must run before
# `import torch` below to affect torch's libraries.
# NOTE(review): presumably needed so a custom native extension can resolve
# torch symbols — confirm which library requires it.
sys.setdlopenflags(sys.getdlopenflags() | ctypes.RTLD_GLOBAL)

from torch.nn.utils.rnn import pad_sequence
import random  # NOTE(review): unused in the code shown here
import os
import re
import sys  # NOTE(review): duplicate of the `import sys` above
import argparse  # NOTE(review): unused in the code shown here
import timeit  # NOTE(review): unused in the code shown here
import torch

# Fix torch's RNG seed so repeated runs are reproducible.
torch.manual_seed(42)

# Make the directory three levels above this file importable, so that the
# `icetk_glm_130B` import below can be resolved.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path + "/../../..")

from icetk_glm_130B import _IceTokenizer
# Shared tokenizer instance used by tokenize() and get_res().
tokenizer = _IceTokenizer()

def tokenize(contexts, *, verbose=False):
    """Encode a batch of prompts into padded GLM-130B input ids.

    Each prompt is split on ``[MASK]`` / ``[gMASK]`` markers. The text pieces
    are tokenized with the module-level ``tokenizer``, and each marker is
    replaced by its command id. Three rules then apply:

    * When the prompt contains no ``MASK]`` at all, ``[gMASK]`` is appended.
    * An ``eos`` id is appended unless the prompt ends with a mask marker.
    * Every sequence ends with ``sop``.

    Args:
        contexts: iterable of raw prompt strings; must not be empty, because
            the per-prompt results are unpacked with ``zip(*...)``.
        verbose: when True, print the raw contexts, the encoded sequences and
            the lengths. This was unconditional debug output in earlier
            versions.

    Returns:
        A tuple ``(start_ids, start_lengths, mask_positions)``:

        * ``start_ids``: the encoded sequences right-padded with 0 by
          ``pad_sequence(batch_first=True)``.
        * ``start_lengths``: an ``IntTensor`` of unpadded lengths.
        * ``mask_positions``: an ``IntTensor`` holding the index of the
          ``[MASK]`` command in each sequence, or -1 when ``[gMASK]`` is used.
    """
    # Compiled once and shared by split() and findall() for every prompt.
    mask_regex = re.compile(r"\[g?MASK\]")

    def encode(raw_text):
        # A plain [MASK] anywhere selects blank infilling; otherwise the
        # generation mask is [gMASK].
        use_gmask = "[MASK]" not in raw_text
        generation_mask = "[gMASK]" if use_gmask else "[MASK]"

        # split() yields one more piece than findall() yields markers, so zip
        # pairs each marker with the text before it; the trailing piece is
        # tokenized separately afterwards.
        text_list = mask_regex.split(raw_text)
        pattern_list = mask_regex.findall(raw_text)
        seq = []
        for sub_text, pattern in zip(text_list, pattern_list):
            seq.extend(tokenizer.tokenize(sub_text))
            seq.append(tokenizer.get_command(pattern))
        seq.extend(tokenizer.tokenize(text_list[-1]))

        if 'MASK]' not in raw_text:
            seq += [tokenizer.get_command(generation_mask)]
            # Only used for the endswith() check below.
            raw_text += ' ' + generation_mask
        if not raw_text.endswith('MASK]'):
            seq = seq + [tokenizer.get_command('eos')]
        seq = seq + [tokenizer.get_command('sop')]

        if use_gmask:
            mask_position = -1
        else:
            mask_position = seq.index(tokenizer.get_command(generation_mask))
        return torch.IntTensor(seq), mask_position

    if verbose:
        print(contexts)
    start_ids, mask_positions = zip(*[encode(c) for c in contexts])
    if verbose:
        print(start_ids)
    start_lengths = torch.IntTensor([len(ids) for ids in start_ids])
    start_ids = pad_sequence(start_ids, batch_first=True, padding_value=0)
    if verbose:
        print(start_lengths, torch.IntTensor([len(ids) for ids in start_ids]))
    return start_ids, start_lengths, torch.IntTensor(mask_positions)
        
def get_res(tokens_batch, start_lengths):
    """Decode the first beam of every item in a batch of generated ids.

    For item ``i``, three steps are applied to beam 0:

    1. Drop the first ``start_lengths[i]`` ids (the prompt).
    2. Cut the remainder before the first end id (eos 20002 or eop 150005).
    3. Pass the result to ``tokenizer.detokenize``.

    Args:
        tokens_batch: batch of generated ids indexable as
            ``tokens_batch[item][beam][position]``, or None.
        start_lengths: context length of each item. A single-element
            sequence is applied to every item. This keeps the behaviour of
            earlier versions, which always used ``start_lengths[0]``.

    Returns:
        A list with one decoded string per item; empty when ``tokens_batch``
        is None or empty.
    """
    # Command ids that end generation for this tokenizer: eos, eop.
    end_ids = (20002, 150005)
    res = []
    if tokens_batch is None:
        return res

    for i, tokens in enumerate(tokens_batch):
        beam_id = 0  # only the first beam is decoded
        # Use each item's own context length; broadcast a single shared one.
        start = start_lengths[i] if len(start_lengths) > 1 else start_lengths[0]
        token = list(tokens[beam_id][start:])  # exclude context input from the output
        # Cut at the earliest end id of either kind. A tuple (not a set) is
        # used so 0-d tensor elements are compared by value, not by hash.
        for pos, tok in enumerate(token):
            if tok in end_ids:
                token = token[:pos]
                break
        res.append(tokenizer.detokenize(token))
    return res


if __name__ == "__main__":
    # Decode the token ids stored in the file "out". Each line holds
    # whitespace-separated integer ids; together the lines form the beams of
    # a single batch item, and get_res() decodes only the first beam.
    # For reference, the end tokens are eop = 150005 and eos = 20002; a
    # prompt can be encoded with tokenize(["..."]) for comparison.
    filename = "out"  # input file name
    with open(filename, 'r') as f:
        # One list of ints per line, in file order.
        beams = [[int(x) for x in line.split()] for line in f]
    tokens_beam_batch = [beams]

    print(tokens_beam_batch)
    # The file holds no prompt, so nothing is stripped from the front.
    start_lengths = [0]
    res = get_res(tokens_beam_batch, start_lengths)

    print(res)
    
