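"""create_bitblas_ckpt.py

Quantize a BitNet b1.58 checkpoint with BitBLAS and save it as a reloadable
BitBLAS-format checkpoint (weights, quantize_config.json, and the helper
modules copied alongside).

Example invocation (the output path is illustrative; by default it is derived
from the model name, see below):

    python create_bitblas_ckpt.py \
        --model_name_or_path 1bitLLM/bitnet_b1_58-3B \
        --saved_model_path ./models/bitnet_b1_58-3B_bitblas
"""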
import argparse
import torch
import bitblas
from transformers.utils.hub import cached_file
import os
import shutil
from transformers import GenerationConfig
import time
import json

import sys

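# Make the repository root (one directory above this script) importable so the
# local BitNet modeling and tokenizer modules can be found.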
sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + "/../")
from modeling_bitnet import BitnetForCausalLM
from tokenization_bitnet import BitnetTokenizer

filepath = os.path.abspath(__file__)
dirpath = os.path.dirname(filepath)

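# Inference only: disable autograd and enable BitBLAS INFO-level logging.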
torch.set_grad_enabled(False)
bitblas.set_log_level("INFO")

parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="1bitLLM/bitnet_b1_58-3B")
parser.add_argument("--saved_model_path", type=str, default=None)
args = parser.parse_args()

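# Output directory defaults to models/<model_name>_bitblas next to this script
# unless --saved_model_path is given.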
model_name_or_path = args.model_name_or_path
saved_model_path = os.path.join(
    dirpath, "models",
    f"{model_name_or_path}_bitblas") if args.saved_model_path is None else args.saved_model_path


def generate_text(model, tokenizer, prompt, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.lm_head.weight.device)
    # Build position ids for the prompt
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
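    # Note: position_ids is precomputed here but not passed to generate();
    # Hugging Face generate() derives positions internally.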

    generation_config = GenerationConfig(
        max_length=max_length,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
    )

    start_time = time.time()
    output_ids = model.generate(input_ids, generation_config=generation_config)
    end_time = time.time()

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    generation_time = end_time - start_time
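    # output_ids[0] includes the prompt tokens, so this throughput figure
    # slightly overestimates pure generation speed.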
    num_tokens = len(output_ids[0])
    tokens_per_second = num_tokens / generation_time

    print(f"Generated {num_tokens} tokens in {generation_time:.2f} seconds")
    print(f"Tokens per second: {tokens_per_second:.2f}")

    return generated_text


def main():
    model = (
        BitnetForCausalLM.from_pretrained(
            model_name_or_path,
            use_flash_attention_2=False,
            torch_dtype=torch.float16,
        ).cuda().half())
    tokenizer = BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False)

    # print("original model generated text:")
    # print(generate_text(model, tokenizer, "Hi, ", max_length=100))
    input_ids = torch.ones((1, 1), dtype=torch.long).cuda()
    # sanity-check forward pass on a dummy token before quantization
    output = model(input_ids)
    print("original model output:", output)

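    # Quantize in place with BitBLAS; per the flag names, fuse_qkv/fuse_gateup
    # merge the attention QKV and MLP gate/up projections into fused layers.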
    model.quantize(fuse_qkv=True, fuse_gateup=True)
    print("original model generated text:")
    print(generate_text(model, tokenizer, "Hi, ", max_length=100))

    model.save_pretrained(saved_model_path)

    # load quant config
    quant_config_path = cached_file(model_name_or_path, "quantize_config.json")
    with open(quant_config_path, "r") as f:
        quant_config = json.load(f)
    print("quant config:")
    print(quant_config)
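    # Record that this checkpoint is stored in BitBLAS packed format, plus the
    # fusion flags used above, so it can be reloaded via from_quantized
    # (assumes from_quantized reads these fields).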
    quant_config["checkpoint_format"] = "bitblas"
    quant_config["fuse_qkv"] = True
    quant_config["fuse_gateup"] = True

    # save quant config
    quant_config_path = os.path.join(saved_model_path, "quantize_config.json")
    with open(quant_config_path, "w") as f:
        json.dump(quant_config, f)
    print("quant config saved to:", quant_config_path)

    # copy the modeling/tokenizer helper files into the saved model path so the
    # checkpoint is self-contained
    file_list = [
        "configuration_bitnet.py",
        "eval_utils.py",
        "modeling_bitnet.py",
        "tokenization_bitnet.py",
        "utils_quant.py",
        "README.md",
    ]
    for file in file_list:
        file_path = cached_file(model_name_or_path, file)
        os.system(f"cp {file_path} {saved_model_path}")
    # reload the quantized checkpoint to verify it loads and runs end to end
    qmodel = BitnetForCausalLM.from_quantized(saved_model_path).cuda().half()
    print("quantized model generated text:")
    print(generate_text(qmodel, tokenizer, "Hi, ", max_length=100))


if __name__ == '__main__':
    main()