from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Define data loading methods
def load_dolly():
    data = load_dataset('databricks/databricks-dolly-15k', split="train")

    # Concatenate instruction, context, and response into a single text field
    def concatenate_data(x):
        return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']}

    concatenated = data.map(concatenate_data)
    return [text for text in concatenated["text"]]

def load_wikitext():
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")
    # Keep only non-empty samples of a reasonable length (more than 20 words)
    return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20]

# Quantize
# calib_data takes a list of strings; load_dolly() above can be passed instead of WikiText
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext())

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
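
# Optional follow-up (illustrative sketch, not part of the original script): once the
# quantized weights are saved, they can typically be reloaded for inference with
# AutoAWQ's from_quantized and the saved tokenizer. fuse_layers is an optional
# AutoAWQ flag; adjust the prompt and generation settings for your own setup.
#
# from awq import AutoAWQForCausalLM
# from transformers import AutoTokenizer
#
# model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
# tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
#
# tokens = tokenizer("What is AWQ quantization?", return_tensors="pt").input_ids.cuda()
# output = model.generate(tokens, max_new_tokens=64)
# print(tokenizer.decode(output[0], skip_special_tokens=True))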