auto.py 3.63 KB
Newer Older
1
import os
Casper's avatar
Casper committed
2
import logging
3
from transformers import AutoConfig
4
from awq.models import *
5
from awq.models.base import BaseAWQForCausalLM
6
7
8

# One entry per supported architecture: maps a HuggingFace `config.model_type`
# string to the AWQ causal-LM wrapper class that handles that architecture.
# NOTE: a few keys are intentionally case-sensitive ("RefinedWeb", "Yi") to
# match the exact `model_type` emitted by those checkpoints' configs.
AWQ_CAUSAL_LM_MODEL_MAP = dict(
    mpt=MptAWQForCausalLM,
    llama=LlamaAWQForCausalLM,
    opt=OptAWQForCausalLM,
    RefinedWeb=FalconAWQForCausalLM,
    RefinedWebModel=FalconAWQForCausalLM,
    falcon=FalconAWQForCausalLM,
    bloom=BloomAWQForCausalLM,
    gptj=GPTJAWQForCausalLM,
    gpt_bigcode=GptBigCodeAWQForCausalLM,
    mistral=MistralAWQForCausalLM,
    mixtral=MixtralAWQForCausalLM,
    gpt_neox=GPTNeoXAWQForCausalLM,
    aquila=AquilaAWQForCausalLM,
    Yi=YiAWQForCausalLM,
    qwen=QwenAWQForCausalLM,
    baichuan=BaichuanAWQForCausalLM,
    llava=LlavaAWQForCausalLM,
    qwen2=Qwen2AWQForCausalLM,
    gemma=GemmaAWQForCausalLM,
    stablelm=StableLmAWQForCausalLM,
    starcoder2=Starcoder2AWQForCausalLM,
)

31

32
def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
    """Resolve and validate the architecture type of a model directory.

    Loads the HuggingFace config from ``model_dir`` and returns its
    ``model_type`` string, which is used as the dispatch key into
    ``AWQ_CAUSAL_LM_MODEL_MAP``.

    Args:
        model_dir: Local path or hub id accepted by ``AutoConfig.from_pretrained``.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.
        **model_init_kwargs: Extra kwargs forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        The ``model_type`` string from the loaded config.

    Raises:
        TypeError: If the architecture is not in ``AWQ_CAUSAL_LM_MODEL_MAP``.
    """
    config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
    )
    # Dict membership test directly on the mapping — `.keys()` was redundant.
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type

41

42
43
class AutoAWQForCausalLM:
    """Factory that dispatches to the architecture-specific AWQ model class.

    This class is never instantiated directly; use the classmethods
    ``from_pretrained`` (load an unquantized model, typically prior to
    quantization) or ``from_quantized`` (load an already AWQ-quantized
    checkpoint). Both look up the concrete implementation in
    ``AWQ_CAUSAL_LM_MODEL_MAP`` based on the checkpoint's ``model_type``.
    """

    def __init__(self):
        # Direct construction is a usage error — this class only dispatches.
        raise EnvironmentError(
            "You must instantiate AutoAWQForCausalLM with\n"
            "AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained"
        )

    @classmethod
    def from_pretrained(
        cls,  # renamed from `self`: classmethods receive the class (PEP 8)
        model_path,
        trust_remote_code=True,
        safetensors=True,
        device_map=None,
        download_kwargs=None,
        **model_init_kwargs,
    ) -> "BaseAWQForCausalLM":
        """Load an unquantized model via the matching AWQ wrapper class.

        Args:
            model_path: Local path or hub id of the model.
            trust_remote_code: Forwarded to config/model loading.
            safetensors: Prefer safetensors weight files.
            device_map: Accelerate-style device map (None = default placement).
            download_kwargs: Extra kwargs for the weight download step.
            **model_init_kwargs: Forwarded to the underlying ``from_pretrained``.

        Returns:
            The architecture-specific ``BaseAWQForCausalLM`` subclass instance.
        """
        model_type = check_and_get_model_type(
            model_path, trust_remote_code, **model_init_kwargs
        )

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            model_path,
            model_type,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            device_map=device_map,
            download_kwargs=download_kwargs,
            **model_init_kwargs,
        )

    @classmethod
    def from_quantized(
        cls,  # renamed from `self`: classmethods receive the class (PEP 8)
        quant_path,
        quant_filename="",
        max_seq_len=2048,
        trust_remote_code=True,
        fuse_layers=True,
        use_exllama=False,
        use_exllama_v2=False,
        batch_size=1,
        safetensors=True,
        device_map="balanced",
        max_memory=None,
        offload_folder=None,
        download_kwargs=None,
        **config_kwargs,
    ) -> "BaseAWQForCausalLM":
        """Load an AWQ-quantized checkpoint via the matching wrapper class.

        Args:
            quant_path: Local path or hub id of the quantized model.
            quant_filename: Specific weight file name ("" = auto-detect).
            max_seq_len: Maximum sequence length used for the KV cache.
            trust_remote_code: Forwarded to config/model loading.
            fuse_layers: Enable fused-layer inference kernels.
            use_exllama: Use the ExLlama kernels.
            use_exllama_v2: Use the ExLlamaV2 kernels.
            batch_size: Batch size exported to fused kernels (see env var below).
            safetensors: Prefer safetensors weight files.
            device_map: Accelerate-style device map.
            max_memory: Per-device memory cap for accelerate dispatch.
            offload_folder: Disk offload directory for accelerate dispatch.
            download_kwargs: Extra kwargs for the weight download step.
            **config_kwargs: Forwarded to the underlying ``from_quantized``.

        Returns:
            The architecture-specific ``BaseAWQForCausalLM`` subclass instance.
        """
        # Fused modules read the batch size from this environment variable.
        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
        model_type = check_and_get_model_type(quant_path, trust_remote_code)

        # Backward compatibility: translate the deprecated `max_new_tokens`
        # kwarg into `max_seq_len` instead of failing.
        if config_kwargs.get("max_new_tokens") is not None:
            max_seq_len = config_kwargs["max_new_tokens"]
            logging.warning(
                "max_new_tokens argument is deprecated... gracefully "
                "setting max_seq_len=max_new_tokens."
            )

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path,
            model_type,
            quant_filename,
            max_seq_len,
            trust_remote_code=trust_remote_code,
            fuse_layers=fuse_layers,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
            safetensors=safetensors,
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            download_kwargs=download_kwargs,
            **config_kwargs,
        )