entry.py 3.77 KB
Newer Older
Ji Lin's avatar
Ji Lin committed
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import torch
3
4
import argparse
from lm_eval import evaluator
5
6
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM
Casper Hansen's avatar
Casper Hansen committed
7
8
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
9
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
Ji Lin's avatar
Ji Lin committed
10

11

Casper Hansen's avatar
Casper Hansen committed
12
13
def load_search_result_into_memory(model, search_path):
    """Load AWQ search results from disk and apply them to *model* in place.

    Args:
        model: The full-precision torch model to patch.
        search_path: Path to a ``torch.save``-d dict containing the
            ``"scale"`` and ``"clip"`` entries produced by the search step.
    """
    # Search results are small; load onto CPU regardless of model placement.
    search_results = torch.load(search_path, map_location="cpu")

    apply_scale(model, search_results["scale"])
    apply_clip(model, search_results["clip"])
Ji Lin's avatar
Ji Lin committed
17

18
def run_search(model_path, dump_path, w_bit, q_config):
    """
    Step 1/2: Search the pile for an optimal scaling factor.
    """
    # Pull the tokenizer and the full-precision model from the hub / local path.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoAWQForCausalLM.from_pretrained(model_path)

    # Search only: compute scales/clips without touching the weights yet.
    model.quantize(
        tokenizer,
        w_bit=w_bit,
        q_config=q_config,
        run_search=True,
        run_quant=False,
    )

    # Persist the search results for the later quantization step.
    model.save_quantized(dump_path)
Ji Lin's avatar
Ji Lin committed
31

32
def run_quant(model_path, search_path, dump_path, w_bit, q_config):
    """
    Step 2/2: Use the search results to quantize model weights
    """
    # Recreate the fp model, then patch in the previously-searched scales/clips.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    load_search_result_into_memory(model.model, search_path)

    # Quantize the (already scaled/clipped) weights; no new search.
    model.quantize(
        w_bit=w_bit,
        q_config=q_config,
        run_search=False,
        run_quant=True,
    )

    # Write the quantized checkpoint to disk.
    model.save_quantized(dump_path)
Ji Lin's avatar
Ji Lin committed
45

46
def run_perplexity(model_path, quant_path, w_bit, q_config, device):
    """
    Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
    """
    # Load the quantized weights plus the matching tokenizer.
    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Wrap the model so the evaluation harness can drive it.
    adaptor = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)

    # Zero-shot wikitext perplexity, no result caching.
    results = evaluator.simple_evaluate(
        model=adaptor,
        tasks=['wikitext'],
        num_fewshot=0,
        batch_size=1,
        no_cache=True,
    )

    print(evaluator.make_table(results))

if __name__ == '__main__':
    """
    python -m awq.entry --entry_type search --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq
    python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/pytorch_model.bin --quant_path mpt-7b-8k-chat-awq
    python -m awq.entry --entry_type perplexity --model_path mosaicml/mpt-7b-8k-chat --quant_path mpt-7b-8k-chat-awq
    """
    parser = argparse.ArgumentParser()
    # required + choices: argparse now rejects a missing or misspelled entry
    # type with a proper usage message instead of falling through to a
    # generic runtime exception.
    parser.add_argument('--entry_type', type=str, required=True,
                        choices=['search', 'quant', 'perplexity'],
                        help='The type of task to run (search|quant|perplexity)')
    parser.add_argument('--model_path', type=str, help='Path to hf model')
    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
    parser.add_argument('--device', type=str, default='balanced', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    args = parser.parse_args()

    # Asymmetric (zero-point) quantization with the requested group size.
    q_config = {"zero_point": True, "q_group_size": args.q_group_size}

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, args.w_bit, q_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
    elif args.entry_type == 'perplexity':
        run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
    else:
        # Unreachable once choices=... is enforced; kept as a defensive guard.
        raise ValueError('--entry_type must be one of (search|quant|perplexity)')