"driver/vscode:/vscode.git/clone" did not exist on "9e0d61465612000caf2050cf834f4693792c5770"
test_quant.py 2.56 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer
import torch
import logging
import json
import time
def gptq():

    # Specify paths and hyperparameters for quantization
    model_path = "../../models/qwen/Qwen1___5-7B-Chat"
    quant_path = "./Qwen1.5-7B-4bit-gptq-4"

    quantize_config = BaseQuantizeConfig(
        bits=4, # 4 or 8
        group_size=128,
        damp_percent=0.01,
        desc_act=False,  # False speeds up inference significantly, at the cost of slightly worse perplexity
        static_groups=False,
        sym=True,
        true_sequential=True,
        model_name_or_path=None,
        model_file_base_name="model"
    )
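    # group_size=128 and damp_percent=0.01 are the usual GPTQ defaults; smaller
    # groups cost extra storage for per-group scales but usually improve accuracy.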
    max_len = 8192
    
    # Load your tokenizer and model with AutoGPTQ
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoGPTQForCausalLM.from_pretrained(model_path, quantize_config)

    # https://huggingface.co/datasets/llm-wizard/alpaca-gpt4-data-zh
    file_path = '../alpaca_gpt4_data_zh.json'
    # file_path = 'oaast_rm_zh.json'
    messages = []
    with open(file_path, 'r', encoding='UTF-8') as fcc_file:
        fcc_data = json.load(fcc_file)
        for tmp_data in fcc_data:
            instruction = tmp_data['instruction']
            input_text = tmp_data['input']
            output = tmp_data['output']
            # In alpaca_gpt4_data_zh.json the 'output' field is a single string;
            # wrap it in a list so the loop builds one message per response
            # instead of one per character. (Files such as oaast_rm_zh.json
            # store a list of responses, which passes through unchanged.)
            if isinstance(output, str):
                output = [output]
            for op in output:
                msg = [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": instruction + input_text},
                    {"role": "assistant", "content": op}
                ]
                messages.append(msg)
                
    print('len(messages):', len(messages))
    messages = messages[:5]  # keep only a handful of calibration samples for this quick test
    print('len(messages):', len(messages))
    
    data = []
    for msg in messages:
        text = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=False)
        model_inputs = tokenizer([text])
        # Truncate the token sequence itself to max_len; slicing
        # model_inputs.input_ids[:max_len] would only cut the length-1 batch
        # dimension and never actually truncate the tokens.
        input_ids = torch.tensor([model_inputs.input_ids[0][:max_len]], dtype=torch.int)
        data.append(dict(input_ids=input_ids, attention_mask=input_ids.ne(tokenizer.pad_token_id)))
    print('len(data):', len(data))
    # Configure logging before quantizing so AutoGPTQ's per-layer progress is visible
    logging.basicConfig(format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")
    t1 = time.time()
    model.quantize(data, cache_examples_on_gpu=False)
    model.save_quantized(quant_path, use_safetensors=True)
    tokenizer.save_pretrained(quant_path)
    t2 = time.time()
    print('time: {:.2f}s'.format(t2 - t1))
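
# A minimal sketch of reloading the saved checkpoint for a smoke test; the
# helper name, device string, and prompt are illustrative assumptions, not
# part of the original script. AutoGPTQForCausalLM.from_quantized is
# AutoGPTQ's standard loader for GPTQ checkpoints.
def check_quantized(quant_path="./Qwen1.5-7B-4bit-gptq-4", device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(quant_path)
    model = AutoGPTQForCausalLM.from_quantized(quant_path, device=device, use_safetensors=True)
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False, add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt").to(device)
    out = model.generate(**inputs, max_new_tokens=32)
    # Strip the prompt tokens and print only the generated continuation
    print(tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))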

if __name__ == "__main__":
    gptq()