auto.py
from inspect import signature
from typing import Dict, Optional, Union

from ._base import BaseGPTQForCausalLM, BaseQuantizeConfig
from ._utils import check_and_get_model_type
from .baichuan import BaiChuanGPTQForCausalLM
from .bloom import BloomGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
from .cohere import CohereGPTQForCausalLM
from .decilm import DeciLMGPTQForCausalLM
from .gemma import GemmaGPTQForCausalLM
from .gemma2 import Gemma2GPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .longllama import LongLlamaGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
from .mixtral import MixtralGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .phi import PhiGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .qwen2 import Qwen2GPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .stablelmepoch import StableLMEpochGPTQForCausalLM
from .starcoder2 import Starcoder2GPTQForCausalLM
from .xverse import XverseGPTQForCausalLM
from .yi import YiGPTQForCausalLM


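# Maps the `model_type` field of a checkpoint's config to the GPTQ wrapper class
# that knows how to quantize and load that architecture.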
GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
    "gpt_neox": GPTNeoXGPTQForCausalLM,
    "gptj": GPTJGPTQForCausalLM,
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
    "codegen": CodeGenGPTQForCausalLM,
    "cohere": CohereGPTQForCausalLM,
    "RefinedWebModel": RWGPTQForCausalLM,
    "RefinedWeb": RWGPTQForCausalLM,
    "falcon": RWGPTQForCausalLM,
    "baichuan": BaiChuanGPTQForCausalLM,
    "internlm": InternLMGPTQForCausalLM,
    "qwen": QwenGPTQForCausalLM,
    "mistral": MistralGPTQForCausalLM,
    "Yi": YiGPTQForCausalLM,
    "xverse": XverseGPTQForCausalLM,
    "deci": DeciLMGPTQForCausalLM,
    "stablelm_epoch": StableLMEpochGPTQForCausalLM,
    "starcoder2": Starcoder2GPTQForCausalLM,
    "mixtral": MixtralGPTQForCausalLM,
    "qwen2": Qwen2GPTQForCausalLM,
    "longllama": LongLlamaGPTQForCausalLM,
    "gemma": GemmaGPTQForCausalLM,
    "gemma2": Gemma2GPTQForCausalLM,
    "phi": PhiGPTQForCausalLM,
    "mpt": MPTGPTQForCausalLM,
}


class AutoGPTQForCausalLM:
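    """Factory that dispatches `from_pretrained`/`from_quantized` to the
    architecture-specific GPTQ wrapper class looked up in ``GPTQ_CAUSAL_LM_MODEL_MAP``
    by the checkpoint's ``model_type``; not meant to be instantiated directly."""
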
    def __init__(self):
        raise EnvironmentError(
            "AutoGPTQForCausalLM is designed to be instantiated\n"
            "using `AutoGPTQForCausalLM.from_pretrained` if you want to quantize a pretrained model, or\n"
            "using `AutoGPTQForCausalLM.from_quantized` if you want to run inference with a quantized model."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        **model_init_kwargs,
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code)
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            quantize_config=quantize_config,
            max_memory=max_memory,
            trust_remote_code=trust_remote_code,
            **model_init_kwargs,
        )

    @classmethod
    def from_quantized(
        cls,
        model_name_or_path: Optional[str],
        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        inject_fused_attention: bool = False,
        inject_fused_mlp: bool = False,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = True,
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        trainable: bool = False,
        disable_exllama: Optional[bool] = None,
        disable_exllamav2: bool = False,
        use_marlin: bool = False,
        use_tritonv2: bool = False,
        **kwargs,
    ) -> BaseGPTQForCausalLM:
        # If the caller did not set disable_exllama explicitly, prefer the exllamav2 kernel:
        # disable the exllama (v1) kernel when exllamav2 is enabled, and fall back to the
        # exllama (v1) kernel (rather than the cuda/cuda_old ones) when exllamav2 is disabled.
        if disable_exllama is None:
            if disable_exllamav2:
                disable_exllama = False
            else:
                disable_exllama = True

        model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
        quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
        # A static list of kwargs needed for huggingface_hub
        huggingface_kwargs = [
            "cache_dir",
            "force_download",
            "proxies",
            "resume_download",
            "local_files_only",
            "use_auth_token",
            "revision",
            "subfolder",
            "_raise_exceptions_for_missing_entries",
            "_commit_hash",
        ]
        # TODO: do we need this filtering of kwargs? @PanQiWei is there a reason we can't just pass all kwargs?
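        # Forward only the kwargs that `quant_func` declares in its signature, plus the hub
        # download kwargs listed above; anything else in `kwargs` is silently dropped here.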
        keywords = {
            key: kwargs[key]
            for key in list(signature(quant_func).parameters.keys()) + huggingface_kwargs
            if key in kwargs
        }
        return quant_func(
            model_name_or_path=model_name_or_path,
            device_map=device_map,
            max_memory=max_memory,
            device=device,
            low_cpu_mem_usage=low_cpu_mem_usage,
            use_triton=use_triton,
            inject_fused_attention=inject_fused_attention,
            inject_fused_mlp=inject_fused_mlp,
            use_cuda_fp16=use_cuda_fp16,
            quantize_config=quantize_config,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            trainable=trainable,
            disable_exllama=disable_exllama,
            disable_exllamav2=disable_exllamav2,
            use_marlin=use_marlin,
            use_tritonv2=use_tritonv2,
            **keywords,
        )


__all__ = ["AutoGPTQForCausalLM"]
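
# Minimal usage sketch (mirrors the upstream AutoGPTQ README; the model id,
# calibration text, and quantization parameters below are illustrative only):
#
#   from transformers import AutoTokenizer
#   from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#   examples = [tokenizer("auto-gptq is an easy-to-use quantization package.")]
#
#   quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
#   model = AutoGPTQForCausalLM.from_pretrained("facebook/opt-125m", quantize_config)
#   model.quantize(examples)
#   model.save_quantized("opt-125m-4bit-gptq")
#
#   model = AutoGPTQForCausalLM.from_quantized("opt-125m-4bit-gptq", device="cuda:0")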