# auto.py — dispatch entry point mapping HF model types to AWQ wrapper classes.
from transformers import AutoConfig
from awq.models import *
from awq.models.base import BaseAWQForCausalLM

# Maps a HF `config.model_type` string to the AWQ wrapper class that knows
# how to load/quantize that architecture (classes come from `awq.models`).
AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
    "llama": LlamaAWQForCausalLM,
    "opt": OptAWQForCausalLM,
}

def check_and_get_model_type(model_dir, trust_remote_code=True):
    """Return the ``model_type`` declared in the model's HF config.

    Args:
        model_dir: local path or hub id passed to ``AutoConfig.from_pretrained``.
        trust_remote_code: forwarded to ``AutoConfig.from_pretrained``.

    Raises:
        TypeError: if the detected model type has no AWQ implementation
            registered in ``AWQ_CAUSAL_LM_MODEL_MAP``.
    """
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    # Membership test directly on the dict — `.keys()` was redundant.
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type

class AutoAWQForCausalLM:
    """Factory that dispatches to the model-specific AWQ wrapper class.

    Not instantiable: use the :meth:`from_pretrained` or
    :meth:`from_quantized` classmethods, which look up the architecture in
    ``AWQ_CAUSAL_LM_MODEL_MAP`` via the model's HF config.
    """

    # Fallback quantization settings used when the caller supplies no
    # (or an empty) quant_config.
    default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    def __init__(self):
        raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                               'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')

    @classmethod
    def from_pretrained(cls, model_path, trust_remote_code=True) -> BaseAWQForCausalLM:
        """Load an unquantized model via its architecture-specific AWQ class.

        Args:
            model_path: local path or hub id of the pretrained model.
            trust_remote_code: forwarded to the HF config/model loaders.

        Raises:
            TypeError: if the model's architecture is not supported.
        """
        model_type = check_and_get_model_type(model_path, trust_remote_code)
        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            model_path, model_type, trust_remote_code=trust_remote_code
        )

    @classmethod
    def from_quantized(cls, quant_path, quant_filename, quant_config=None,
                       device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
        """Load an AWQ-quantized checkpoint via its architecture-specific class.

        Args:
            quant_path: local path or hub id of the quantized model.
            quant_filename: name of the quantized weights file.
            quant_config: quantization settings; ``None`` (or any falsy value,
                e.g. ``{}``, matching the original behavior) falls back to
                ``default_quant_config``. Default changed from a mutable ``{}``
                to ``None`` to avoid the shared-mutable-default pitfall;
                callers see identical behavior.
            device: device map / placement strategy (default ``'balanced'``).
            trust_remote_code: forwarded to the HF config/model loaders.

        Raises:
            TypeError: if the model's architecture is not supported.
        """
        model_type = check_and_get_model_type(quant_path, trust_remote_code)
        quant_config = quant_config if quant_config else cls.default_quant_config
        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
        )