entry.py 2.12 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import torch
Casper Hansen's avatar
Casper Hansen committed
3
4
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
Casper Hansen's avatar
Casper Hansen committed
5
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
Ji Lin's avatar
Ji Lin committed
6

Casper Hansen's avatar
Casper Hansen committed
7
# Residue of a removed argparse CLI: the original parsed "--max_memory
# DEVICE:SIZE" pairs into {device: size}. With the argument source hard-coded
# to None, `(None or [])` is always [], so both comprehensions always produce
# an empty mapping — state that directly.
max_memory = {}

Casper Hansen's avatar
Casper Hansen committed
10
11
12
13
14
15
16
def get_awq_model(model):
    """Return the AWQ wrapper object matching *model*'s architecture.

    Currently only MPT-family models are supported; any other architecture
    raises NotImplementedError.
    """
    from awq.models import MptAWQForCausalLM

    architecture = str(model.__class__).lower()
    if "mpt" not in architecture:
        raise NotImplementedError(type(model))
    return MptAWQForCausalLM()
Ji Lin's avatar
Ji Lin committed
17

Casper Hansen's avatar
Casper Hansen committed
18
def load_unquantized(model_path):
    """Load the un-quantized fp16 causal LM and its tokenizer.

    Returns a ``(model, tokenizer)`` pair with the model in eval mode.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # NOTE(review): the tokenizer is resolved via config.tokenizer_name rather
    # than model_path — assumes the checkpoint config carries that field
    # (true for MPT); confirm for other architectures.
    tokenizer = AutoTokenizer.from_pretrained(
        config.tokenizer_name, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.eval()
    return model, tokenizer
Abhinav Kulkarni's avatar
Abhinav Kulkarni committed
29

Casper Hansen's avatar
Casper Hansen committed
30
31
def load_search_result_into_memory(model, search_path):
    """Apply saved AWQ search results (scales and clips) to *model* in place.

    *search_path* points at a ``torch.save``d dict with "scale" and "clip"
    entries, as produced by the search step.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load search
    # results from trusted sources.
    checkpoint = torch.load(search_path, map_location="cpu")

    apply_scale(model, checkpoint["scale"])
    apply_clip(model, checkpoint["clip"])
Ji Lin's avatar
Ji Lin committed
35

Casper Hansen's avatar
Casper Hansen committed
36
37
38
39
def run_search(model, dump_path):
    """Run the AWQ scale/clip search and save the results to *dump_path*.

    NOTE(review): the *model* parameter is currently ignored — the function
    loads a fresh checkpoint from the module-level ``model_path`` instead.
    Kept as-is for interface compatibility; confirm intent with callers.
    Also relies on the module-level ``q_config``.
    """
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
    awq_results = awq_model.quantize(
        model, tokenizer, w_bit=4, q_config=q_config,
        run_search=True, run_quant=False)

    # os.makedirs("") raises FileNotFoundError, so only create the directory
    # when dump_path actually has a directory component.
    dirpath = os.path.dirname(dump_path)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    torch.save(awq_results, dump_path)
Ji Lin's avatar
Ji Lin committed
44

Casper Hansen's avatar
Casper Hansen committed
45
46
47
def run_quant(model, search_path, dump_path):
    """Quantize the model to 4-bit weights using saved search results and
    save its state_dict to *dump_path*.

    NOTE(review): the *model* parameter is currently ignored — the function
    loads a fresh checkpoint from the module-level ``model_path`` instead.
    Kept as-is for interface compatibility; confirm intent with callers.
    Also relies on the module-level ``q_config``.
    """
    model, tokenizer = load_unquantized(model_path)
    load_search_result_into_memory(model, search_path)

    awq_model = get_awq_model(model)
    awq_model.quantize(model, w_bit=4, q_config=q_config,
                       run_search=False, run_quant=True)

    # os.makedirs("") raises FileNotFoundError, so only create the directory
    # when dump_path actually has a directory component.
    dirpath = os.path.dirname(dump_path)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    # Move to CPU before saving so the checkpoint is device-independent.
    torch.save(model.cpu().state_dict(), dump_path)
Ji Lin's avatar
Ji Lin committed
55

Casper Hansen's avatar
Casper Hansen committed
56
57
58
59
# Module-level configuration for the MPT-7B-8k-chat AWQ pipeline.
model_path = "./mpt-7b-8k-chat"
# Search results and quantized weights live alongside the checkpoint.
search_path = model_path + "/mpt-7b-8k-chat-awq-search.pt"
quant_path = model_path + "/mpt-7b-8k-chat-w4-g128.pt"
# Quantization settings consumed by run_search/run_quant.
q_config = dict(zero_point=True, q_group_size=128)