import os
import json
from typing import Dict, Optional, List
from dataclasses import dataclass, field
from transformers.utils.hub import PushToHubMixin, cached_file

@dataclass
class AwqConfig(PushToHubMixin):
    """Quantization settings for an AWQ model.

    Serializes to/from both the native AWQ dict layout (``to_dict``) and the
    Hugging Face ``transformers`` ``quantization_config`` layout
    (``to_transformers_dict`` / ``from_transformers_dict``), which use
    different key names (``q_group_size`` vs ``group_size``, ``w_bit`` vs
    ``bits``).
    """

    quant_method: str = field(default="awq")  # identifier written into transformers configs
    zero_point: bool = field(default=True)  # asymmetric (zero-point) quantization
    q_group_size: int = field(default=128)  # quantization group size
    w_bit: int = field(default=4)  # weight bit-width
    version: str = field(default="gemm")  # kernel version; compared lowercased
    # Plain (unannotated) class attribute, so it is NOT a dataclass field.
    config_file_name = "config.json"
    modules_to_not_convert: Optional[List] = None

    @classmethod
    def from_dict(cls, quant_config: Optional[Dict] = None):
        """Build a config from a plain dict; ``None`` or ``{}`` yields all defaults.

        Fix: the previous signature used a mutable ``{}`` default argument,
        which is shared across calls. ``None`` is the safe, backward-compatible
        equivalent (both are falsy and take the same branch).
        """
        if not quant_config:
            config = cls()
        else:
            config = cls(**quant_config)
            # Normalize so downstream version comparisons are case-insensitive.
            config.version = config.version.lower()

        return config

    @classmethod
    def from_pretrained(cls, save_dir: str, **kwargs):
        """Load the quantization config from a local directory or a hub repo.

        Returns a default ``AwqConfig`` when no config file (or no
        ``quantization_config`` entry inside it) can be found.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        commit_hash = kwargs.pop("_commit_hash", None)

        if os.path.isdir(save_dir):  # Local
            resolved_config_file = os.path.join(save_dir, cls.config_file_name)
        else:  # Remote
            resolved_config_file = cached_file(
                save_dir,
                cls.config_file_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                use_auth_token=use_auth_token,
                revision=revision,
                local_files_only=local_files_only,
                subfolder=subfolder,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                _commit_hash=commit_hash,
            )

        quant_config = None
        # Fix: cached_file returns None when the entry is missing and exception
        # raising is suppressed; os.path.exists(None) raises TypeError, so the
        # None case must be guarded explicitly.
        if resolved_config_file is not None and os.path.exists(resolved_config_file):
            with open(resolved_config_file, "r", encoding="utf-8") as file:
                loaded_config = json.load(file)

            quant_config = loaded_config.get("quantization_config")

            if quant_config is not None:
                # from_transformers_dict takes an explicit first positional
                # parameter; pass the class to keep that existing contract.
                awq_config = cls.from_transformers_dict(cls, quant_config)
                quant_config = cls(**awq_config)

        if quant_config is None:
            quant_config = cls()

        return quant_config

    def to_dict(self) -> Dict:
        """Native AWQ dict layout (``quant_method`` is intentionally omitted)."""
        return {
            "zero_point": self.zero_point,
            "q_group_size": self.q_group_size,
            "w_bit": self.w_bit,
            "version": self.version,
            "modules_to_not_convert": self.modules_to_not_convert,
        }

    def to_transformers_dict(self) -> Dict:
        """Hugging Face ``quantization_config`` layout (renamed keys)."""
        return {
            "quant_method": self.quant_method,
            "zero_point": self.zero_point,
            "group_size": self.q_group_size,
            "bits": self.w_bit,
            "version": self.version.lower(),
            "modules_to_not_convert": self.modules_to_not_convert,
        }

    def from_transformers_dict(self, transformers_dict: Dict) -> Dict:
        """Map a Hugging Face ``quantization_config`` dict to AWQ field names.

        NOTE(review): despite the ``self`` parameter this is used as an
        unbound helper (called as ``cls.from_transformers_dict(cls, d)``)
        and reads no instance state; signature kept for backward
        compatibility with any external callers.
        """
        return {
            "quant_method": transformers_dict.get("quant_method"),
            "zero_point": transformers_dict.get("zero_point"),
            "q_group_size": transformers_dict.get("group_size"),
            "w_bit": transformers_dict.get("bits"),
            "version": transformers_dict.get("version"),
            "modules_to_not_convert": transformers_dict.get("modules_to_not_convert"),
        }