import os
import torch
import argparse


class GlmWeights(object):
    """Buffers and converter for one tensor-parallel shard of a GLM checkpoint.

    The instance keeps three flat lists of tensors:

    * ``self.w``      -- biases, layernorm parameters and the embedding table
      (always float16 placeholders),
    * ``self.weight`` -- the large matrix weights (float16, int8, or int8 at
      half the element count for any other ``dtype`` such as ``'int4'``),
    * ``self.scale``  -- per-matrix scales, only populated for ``'int8'`` and
      ``'int4'``.

    Within each list, tensors are grouped by parameter kind: all ``layer_num``
    layers of the first kind come first, then all layers of the second kind,
    and so on. ``convert`` fills the lists from a checkpoint and dumps each
    list to its own binary file.

    Indexing, assignment and ``len()`` operate on ``self.w`` only.
    """

    def __init__(self, head_num, size_per_head, layer_num, vocab_size, max_seq_len, tensor_para_size, pipeline_para_size, dtype):
        """Allocate zero placeholders that define the expected tensor sizes.

        Args:
            head_num: Total number of attention heads (across all ranks).
            size_per_head: Hidden units per attention head.
            layer_num: Number of transformer layers.
            vocab_size: Vocabulary size of the embedding table.
            max_seq_len: Stored on the instance; not used by this class.
            tensor_para_size: Number of tensor-parallel ranks; must divide
                ``head_num``.
            pipeline_para_size: Number of pipeline-parallel stages.
            dtype: ``'fp16'``, ``'int8'`` or ``'int4'``. Any value other than
                ``'fp16'``/``'int8'`` selects the half-size int8 weight layout.
        """
        assert head_num % tensor_para_size == 0, "head_num must be divisible by tensor_para_size"
        self.head_num = head_num
        self.size_per_head = size_per_head
        self.layer_num = layer_num
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.tensor_para_size = tensor_para_size
        self.pipeline_para_size = pipeline_para_size
        self.layers_per_device = layer_num // pipeline_para_size

        # Sizes seen by a single tensor-parallel rank ("local") versus the
        # whole model ("global").
        local_head_num = head_num // tensor_para_size
        global_head_num = head_num
        local_hidden_units = local_head_num * size_per_head
        global_hidden_units = global_head_num * size_per_head
        local_inter_size = local_hidden_units * 8 // 3

        self.local_head_num = local_head_num
        self.global_head_num = global_head_num
        self.local_hidden_units = local_hidden_units
        self.global_hidden_units = global_hidden_units
        self.local_inter_size = local_inter_size
        self.dtype = dtype

        self.w = []
        self.weight = []
        self.scale = []

        def per_layer(numel, tensor_dtype=torch.float16):
            # A single zero tensor is aliased layer_num times on purpose:
            # convert() rebinds list entries instead of writing into the
            # tensors, so sharing avoids allocating layer_num large buffers.
            return [torch.zeros(numel, dtype=tensor_dtype)] * layer_num

        # Transformer blocks: biases and layernorm parameters.
        self.w.extend(per_layer(3 * local_hidden_units))   # attention.query_key_value.bias
        self.w.extend(per_layer(global_hidden_units))      # attention.dense.bias
        self.w.extend(per_layer(global_hidden_units))      # input_layernorm.bias
        self.w.extend(per_layer(global_hidden_units))      # input_layernorm.weight
        self.w.extend(per_layer(local_inter_size))         # mlp.dense_h_to_4h.bias.1
        self.w.extend(per_layer(local_inter_size))         # mlp.dense_h_to_4h.bias.2
        self.w.extend(per_layer(global_hidden_units))      # mlp.dense_4h_to_h.bias
        self.w.extend(per_layer(global_hidden_units))      # post_attention_layernorm.bias
        self.w.extend(per_layer(global_hidden_units))      # post_attention_layernorm.weight

        # Transformer blocks: matrix weights. 'fp16' and 'int8' store one
        # element per weight; every other dtype stores half as many int8
        # elements (two 4-bit values per byte for 'int4').
        if dtype in ['fp16', 'int8']:
            w_type = torch.int8 if dtype == 'int8' else torch.float16
            pack = 1
        else:
            w_type = torch.int8
            pack = 2
        self.weight.extend(per_layer(global_hidden_units * 3 * local_hidden_units // pack, w_type))  # attention.query_key_value.weight
        self.weight.extend(per_layer(local_hidden_units * global_hidden_units // pack, w_type))      # attention.dense.weight
        self.weight.extend(per_layer(global_hidden_units * local_inter_size // pack, w_type))        # mlp.dense_h_to_4h.weight.1
        self.weight.extend(per_layer(global_hidden_units * local_inter_size // pack, w_type))        # mlp.dense_h_to_4h.weight.2
        self.weight.extend(per_layer(local_inter_size * global_hidden_units // pack, w_type))        # mlp.dense_4h_to_h.weight

        # Quantization scales, in the same order as the matrix weights above.
        if dtype in ['int8', 'int4']:
            self.scale.extend(per_layer(3 * local_hidden_units))
            self.scale.extend(per_layer(global_hidden_units))
            self.scale.extend(per_layer(local_inter_size))
            self.scale.extend(per_layer(local_inter_size))
            self.scale.extend(per_layer(global_hidden_units))

        # After Transformer blocks
        self.w.append(torch.zeros(global_hidden_units, dtype=torch.float16))   # layernorm_gamma final_layernorm.weight
        self.w.append(torch.zeros(global_hidden_units, dtype=torch.float16))   # layernorm_beta  final_layernorm.bias
        self.w.append(torch.zeros(vocab_size * global_hidden_units // tensor_para_size, dtype=torch.float16))  # embedding_table model.wte

    def __getitem__(self, idx):
        """Return entry ``idx`` of ``self.w``."""
        return self.w[idx]

    def __setitem__(self, idx, val):
        """Replace entry ``idx`` of ``self.w`` with ``val``."""
        self.w[idx] = val

    def __len__(self):
        """Return the number of entries in ``self.w``."""
        return len(self.w)

    def _map(self, func):
        """Replace every tensor in ``w``, ``weight`` and ``scale`` by ``func(tensor)``.

        Entries that are themselves lists are mapped element by element.
        """
        for group in (self.w, self.weight, self.scale):
            for i, item in enumerate(group):
                if isinstance(item, list):
                    for j, inner in enumerate(item):
                        item[j] = func(inner)
                else:
                    group[i] = func(item)

    @staticmethod
    def _reshape_into(loaded, slots, label):
        """Store each non-empty loaded tensor in ``slots``, reshaped to the slot's shape.

        Raises:
            RuntimeError: If a tensor does not fit its slot, or there is no
                slot for it. The underlying error is chained as the cause.
        """
        for i, tensor in enumerate(loaded):
            if tensor.nelement() == 0:
                # Empty tensors keep the zero placeholder.
                continue
            try:
                slots[i] = tensor.reshape(slots[i].shape)
            except (RuntimeError, IndexError) as err:
                raise RuntimeError(
                    "shape error: cannot fit loaded tensor {} of '{}' with shape {}".format(
                        i, label, tuple(tensor.shape))
                ) from err

    @staticmethod
    def _save_to_file(tensors, filename):
        """Write the raw bytes of every tensor, back to back, into ``filename``."""
        with open(filename, 'wb') as f:
            for tensor in tensors:
                # Convert the torch.Tensor to a numpy.ndarray, then write the
                # array to the file in binary form.
                tensor.numpy().tofile(f)

    def convert(self, ckpt_path, tensor_para_rank, pipeline_para_rank, output_dir="."):
        """Convert one tensor-parallel rank of a checkpoint into binary files.

        Loads ``mp_rank_<rank>_model_states.pt`` from ``ckpt_path``, regroups
        the fused query/key/value parameters so the three splits are
        contiguous, splits ``mlp.dense_h_to_4h`` into its two halves, and
        writes ``GPU-<rank>-weights``, ``GPU-<rank>-quant_weights`` and
        ``GPU-<rank>-quant_scale`` into ``output_dir`` (created if missing).

        Args:
            ckpt_path: Folder containing the per-rank checkpoint files.
            tensor_para_rank: Rank whose checkpoint file is converted.
            pipeline_para_rank: Accepted for interface compatibility; unused.
            output_dir: Destination folder for the binary files.

        Returns:
            ``False`` if ``ckpt_path`` does not exist, ``True`` otherwise.

        Raises:
            RuntimeError: If a loaded tensor does not match its placeholder.
        """
        if not os.path.exists(ckpt_path):
            return False

        checkpoint_name = os.path.join(ckpt_path, 'mp_rank_{:02d}_model_states.pt'.format(tensor_para_rank))

        # NOTE: torch.load unpickles the file; only convert checkpoints that
        # come from a trusted source.
        module = torch.load(checkpoint_name, map_location='cpu')['module']

        layer_num = self.layer_num
        quantized = self.dtype in ['int8', 'int4']
        num_splits = 3  # the fused query_key_value parameter holds three splits

        def layer_tensors(name):
            # The parameter `name` of every layer, in layer order.
            return [module[f'transformer.layers.{i}.{name}'] for i in range(layer_num)]

        def maybe_transposed(tensors):
            # Quantized matrices are written as stored; fp16 ones are transposed.
            return tensors if quantized else [t.T for t in tensors]

        w = []
        weight = []
        scale = []

        # attention.query_key_value: the checkpoint interleaves the splits per
        # head (head, split, size_per_head); regroup to (split, head, size_per_head).
        hidden_dim, local_dim = module['transformer.layers.0.attention.query_key_value.weight'].T.shape
        local_dim = local_dim // num_splits
        # Use the configured head count so models other than the 96-head one
        # are regrouped correctly.
        size_per_head = hidden_dim // self.head_num
        if self.dtype == 'int4':
            # hidden_dim is halved for packed int4 weights; undo that here.
            size_per_head *= 2
        head_num = self.head_num // self.tensor_para_size

        if quantized:
            scale.extend([
                t.reshape(head_num, num_splits, size_per_head).permute(1, 0, 2).reshape(num_splits, local_dim)
                for t in layer_tensors('attention.query_key_value.weight_scale')
            ])
        qkv_weights = [
            t.T.reshape(hidden_dim, head_num, num_splits, size_per_head)
            .permute(0, 2, 1, 3)
            .reshape(hidden_dim, num_splits * local_dim)
            for t in layer_tensors('attention.query_key_value.weight')
        ]
        weight.extend([t.T for t in qkv_weights] if quantized else qkv_weights)

        bias_local_dim = module['transformer.layers.0.attention.query_key_value.bias'].shape[0] // num_splits
        bias_size_per_head = bias_local_dim // head_num
        w.extend([
            t.reshape(head_num, num_splits, bias_size_per_head).permute(1, 0, 2).reshape(num_splits, bias_local_dim)
            for t in layer_tensors('attention.query_key_value.bias')
        ])

        # attention.dense
        if quantized:
            scale.extend(layer_tensors('attention.dense.weight_scale'))
        weight.extend(maybe_transposed(layer_tensors('attention.dense.weight')))

        w.extend(layer_tensors('attention.dense.bias'))
        w.extend(layer_tensors('input_layernorm.bias'))
        w.extend(layer_tensors('input_layernorm.weight'))

        # mlp.dense_h_to_4h is stored as two stacked halves; emit them separately.
        half = module['transformer.layers.0.mlp.dense_h_to_4h.weight'].shape[0] // 2

        if quantized:
            scale.extend([t[:half] for t in layer_tensors('mlp.dense_h_to_4h.weight_scale')])
        weight.extend(maybe_transposed([t[:half, :] for t in layer_tensors('mlp.dense_h_to_4h.weight')]))
        w.extend([t[:half] for t in layer_tensors('mlp.dense_h_to_4h.bias')])

        if quantized:
            scale.extend([t[half:] for t in layer_tensors('mlp.dense_h_to_4h.weight_scale')])
        weight.extend(maybe_transposed([t[half:, :] for t in layer_tensors('mlp.dense_h_to_4h.weight')]))
        w.extend([t[half:] for t in layer_tensors('mlp.dense_h_to_4h.bias')])

        # mlp.dense_4h_to_h
        if quantized:
            scale.extend(layer_tensors('mlp.dense_4h_to_h.weight_scale'))
        weight.extend(maybe_transposed(layer_tensors('mlp.dense_4h_to_h.weight')))

        w.extend(layer_tensors('mlp.dense_4h_to_h.bias'))
        w.extend(layer_tensors('post_attention_layernorm.bias'))
        w.extend(layer_tensors('post_attention_layernorm.weight'))

        # After the transformer blocks.
        w.append(module['transformer.final_layernorm.weight'])
        w.append(module['transformer.final_layernorm.bias'])
        w.append(module['transformer.word_embeddings.weight'])

        print(len(w))

        # Flatten every loaded tensor into the shape of its placeholder.
        self._reshape_into(w, self.w, "w")
        self._reshape_into(weight, self.weight, "weight")
        if quantized:
            self._reshape_into(scale, self.scale, "scale")

        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fileprefix = os.path.join(output_dir, "GPU-" + str(tensor_para_rank) + "-")

        self._save_to_file(self.w, fileprefix + "weights")
        self._save_to_file(self.weight, fileprefix + "quant_weights")
        self._save_to_file(self.scale, fileprefix + "quant_scale")

        return True
    
def parse_arguments():
    """Parse the command-line options of the converter.

    Returns:
        The ``argparse.Namespace`` with ``input_folder`` and ``output_folder``
        both set.

    Raises:
        SystemExit: With status 1, after printing the help text, when either
            folder option is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
    parser.add_argument("-o", "--output-folder", default=None, type=str, help="Output FT model folder")

    args = parser.parse_args()

    if args.input_folder is None or args.output_folder is None:
        print("Both --input-folder and --output-folder must be set!")
        parser.print_help()
        # Exit with a non-zero status so callers can detect the failure; the
        # builtin exit() is injected by `site` and would report success.
        raise SystemExit(1)

    return args

if __name__ == "__main__":
    # Conversion settings are fixed here; only the folders come from the
    # command line.
    tensor_para_size = 8
    precision = "fp16"

    args = parse_arguments()

    ft_dir = args.output_folder
    sat_path = args.input_folder
    print("input_model: ", sat_path)
    print("output_model: ", ft_dir)

    # head_num=96, size_per_head=128, layer_num=70, vocab_size=150528,
    # max_seq_len=10000, pipeline_para_size=1.
    weights = GlmWeights(96, 128, 70, 150528, 10000, tensor_para_size, 1, precision)
    for i in range(tensor_para_size):
        converted = weights.convert(sat_path, tensor_para_rank=i, pipeline_para_rank=0, output_dir=ft_dir)
        if not converted:
            # convert() returns False when the input folder does not exist;
            # stop with a non-zero status instead of silently writing nothing.
            raise SystemExit("Conversion failed for rank {}: input folder not found: {}".format(i, sat_path))

    
