import os
from transformers import AutoConfig
from awq.models import *
from awq.models.base import BaseAWQForCausalLM

# Registry used to dispatch on a checkpoint's ``config.model_type``: each key
# is a model_type string and each value is the AWQ model class that handles it.
# Key spelling and case are significant (note "RefinedWeb" and "Yi").
AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
    "llama": LlamaAWQForCausalLM,
    "opt": OptAWQForCausalLM,
    # Three distinct model_type strings share the Falcon implementation.
    "RefinedWeb": FalconAWQForCausalLM,
    "RefinedWebModel": FalconAWQForCausalLM,
    "falcon": FalconAWQForCausalLM,
    "bloom": BloomAWQForCausalLM,
    "gptj": GPTJAWQForCausalLM,
    "gpt_bigcode": GptBigCodeAWQForCausalLM,
    "mistral": MistralAWQForCausalLM,
    "mixtral": MixtralAWQForCausalLM,
    "gpt_neox": GPTNeoXAWQForCausalLM,
    "aquila": AquilaAWQForCausalLM,
    "Yi": YiAWQForCausalLM,
    "qwen": QwenAWQForCausalLM,
    "baichuan": BaichuanAWQForCausalLM,
    "llava": LlavaAWQForCausalLM,
}

26
27
def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
    """Return the ``model_type`` of the checkpoint at *model_dir* if it is supported.

    The config is loaded with ``AutoConfig.from_pretrained`` and its
    ``model_type`` is looked up in ``AWQ_CAUSAL_LM_MODEL_MAP``.

    Args:
        model_dir: Location of the model, passed straight to
            ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        **model_init_kwargs: Extra keyword arguments forwarded to
            ``AutoConfig.from_pretrained``.

    Returns:
        The config's ``model_type`` value, guaranteed to be a key of
        ``AWQ_CAUSAL_LM_MODEL_MAP``.

    Raises:
        TypeError: If the config's ``model_type`` has no entry in
            ``AWQ_CAUSAL_LM_MODEL_MAP``.
    """
    # NOTE(review): trust_remote_code defaults to True here; per the
    # transformers docs this permits running code shipped with the model
    # repository. Kept as-is for backward compatibility -- confirm this
    # default is intended for untrusted model sources.
    config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
    )
    model_type = config.model_type
    if model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{model_type} isn't supported yet.")
    return model_type

class AutoAWQForCausalLM:
    """Entry point that dispatches to the AWQ model class for a checkpoint.

    This class is not meant to be instantiated. Use the ``from_pretrained``
    or ``from_quantized`` classmethods; each detects the checkpoint's
    ``model_type`` via ``check_and_get_model_type`` and delegates to the
    matching class in ``AWQ_CAUSAL_LM_MODEL_MAP``.
    """

    def __init__(self):
        """Always raise: instances must come from the classmethod constructors.

        Raises:
            EnvironmentError: Unconditionally.
        """
        raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                               'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')

    @classmethod
    def from_pretrained(cls, model_path, trust_remote_code=True, safetensors=False,
                        device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM:
        """Load a model through the class registered for its ``model_type``.

        Args:
            model_path: Model location; used both for model-type detection
                and as the first argument of the delegated call.
            trust_remote_code: Forwarded to model-type detection and to the
                delegated ``from_pretrained``.
            safetensors: Forwarded to the delegated ``from_pretrained``.
            device_map: Forwarded to the delegated ``from_pretrained``.
            **model_init_kwargs: Forwarded to model-type detection and to
                the delegated ``from_pretrained``.

        Returns:
            Whatever the selected class's ``from_pretrained`` returns.

        Raises:
            TypeError: If the detected model type is not supported
                (raised by ``check_and_get_model_type``).
        """
        model_type = check_and_get_model_type(model_path, trust_remote_code, **model_init_kwargs)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
            device_map=device_map, **model_init_kwargs
        )

    @classmethod
    def from_quantized(cls, quant_path, quant_filename='', max_new_tokens=None,
                       trust_remote_code=True, fuse_layers=True,
                       batch_size=1, safetensors=True,
                       device_map="balanced", offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
        """Load a quantized model through the class registered for its ``model_type``.

        Side effect: sets the process-wide environment variable
        ``AWQ_BATCH_SIZE`` to ``str(batch_size)`` before delegating.

        Args:
            quant_path: Location of the quantized model; used both for
                model-type detection and as the first argument of the
                delegated call.
            quant_filename: Forwarded positionally to the delegated
                ``from_quantized``.
            max_new_tokens: Forwarded positionally to the delegated
                ``from_quantized``.
            trust_remote_code: Forwarded to model-type detection and to the
                delegated ``from_quantized``.
            fuse_layers: Forwarded to the delegated ``from_quantized``.
            batch_size: Written to ``os.environ["AWQ_BATCH_SIZE"]`` as a
                string; not passed to the delegated call.
            safetensors: Forwarded to the delegated ``from_quantized``.
            device_map: Forwarded to the delegated ``from_quantized``.
            offload_folder: Forwarded to the delegated ``from_quantized``.
            **config_kwargs: Forwarded to the delegated ``from_quantized``
                only; model-type detection does not receive them.

        Returns:
            Whatever the selected class's ``from_quantized`` returns.

        Raises:
            TypeError: If the detected model type is not supported
                (raised by ``check_and_get_model_type``).
        """
        # Set before model-type detection and dispatch, matching the original order.
        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
        model_type = check_and_get_model_type(quant_path, trust_remote_code)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
            fuse_layers=fuse_layers, safetensors=safetensors,
            device_map=device_map, offload_folder=offload_folder,
            **config_kwargs
        )