import_utils.py
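"""Import-time probes for optional quantization kernel backends (triton,
AutoGPTQ CUDA, exllama/exllamav2, QIGen, Marlin) and helpers for selecting a
matching ``QuantLinear`` implementation and comparing library versions."""
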
from logging import getLogger
from typing import Optional

import torch
from packaging.version import parse as parse_version


try:
    import triton  # noqa: F401

    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False

try:
    import autogptq_cuda_64  # noqa: F401

    AUTOGPTQ_CUDA_AVAILABLE = True
except Exception:
    AUTOGPTQ_CUDA_AVAILABLE = False


try:
    import exllama_kernels  # noqa: F401

    EXLLAMA_KERNELS_AVAILABLE = True
except Exception:
    EXLLAMA_KERNELS_AVAILABLE = False

try:
    import exllamav2_kernels  # noqa: F401

    EXLLAMAV2_KERNELS_AVAILABLE = True
except Exception:
    EXLLAMAV2_KERNELS_AVAILABLE = False

try:
    import cQIGen  # noqa: F401

    QIGEN_AVAILABLE = True
    QIGEN_EXCEPTION = None
except Exception as e:
    QIGEN_AVAILABLE = False
    QIGEN_EXCEPTION = e

try:
    import autogptq_marlin_cuda  # noqa: F401

    MARLIN_AVAILABLE = True
    MARLIN_EXCEPTION = None
except Exception as e:
    MARLIN_AVAILABLE = False
    MARLIN_EXCEPTION = e


logger = getLogger(__name__)


def dynamically_import_QuantLinear(
    use_triton: bool,
    desc_act: bool,
    group_size: int,
    bits: int,
    disable_exllama: Optional[bool] = None,
    disable_exllamav2: bool = False,
    use_qigen: bool = False,
    use_marlin: bool = False,
    use_tritonv2: bool = False,
):
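    """Select and return the ``QuantLinear`` class matching the requested backend
    and quantization configuration.

    Preference order: HPU (when habana_frameworks is importable), QIGen,
    triton / tritonv2, Marlin (4-bit), exllamav2 (4-bit), exllama (4-bit), and
    finally the cuda_old / cuda kernels.
    """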
    # Prefer the Habana HPU kernel whenever habana_frameworks is importable.
    try:
        import habana_frameworks.torch.hpu  # noqa: F401
    except Exception:
        pass
    else:
        from ..nn_modules.qlinear.qlinear_hpu import QuantLinear

        return QuantLinear
    if use_qigen:
        if not QIGEN_AVAILABLE:
            raise ValueError(
                f"QIGen does not appear to be available, with error: {QIGEN_EXCEPTION}. Please check your installation or use `use_qigen=False`."
            )
        from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
    else:
        if use_triton or use_tritonv2:
            if torch.version.hip:
                logger.warning(
                    "Running the GPTQ triton kernels on AMD GPUs is untested and may result in errors or wrong predictions. Please set use_triton=False (or use_tritonv2=False)."
                )
            if use_tritonv2:
                logger.debug("Using tritonv2 for GPTQ")
                from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
            else:
                from ..nn_modules.qlinear.qlinear_triton import QuantLinear
        else:
            # When disable_exllama is not set explicitly, derive it from disable_exllamav2:
            # if exllamav2 is disabled we fall back on the exllama kernel rather than the
            # cuda/cuda_old ones; otherwise exllama v1 is disabled so exllamav2 takes priority.
            if disable_exllama is None:
                disable_exllama = not disable_exllamav2
            if bits == 4 and use_marlin:
                from ..nn_modules.qlinear.qlinear_marlin import QuantLinear
            elif bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
                from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
            elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
                from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
            elif not desc_act or group_size == -1:
                from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
            else:
                from ..nn_modules.qlinear.qlinear_cuda import QuantLinear

    return QuantLinear
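
# Illustrative usage (a hypothetical call site, not invoked in this module):
# select a 4-bit kernel with act-order disabled and build quantized linear
# layers from the returned class.
#
#   QuantLinear = dynamically_import_QuantLinear(
#       use_triton=False, desc_act=False, group_size=128, bits=4
#   )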


def compare_transformers_version(version: str = "v4.28.0", op: str = "eq"):
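    """Return the result of comparing the installed transformers version to
    ``version`` with the operator named by ``op`` (eq/lt/le/gt/ge)."""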
    assert op in ["eq", "lt", "le", "gt", "ge"]

    from transformers import __version__

    return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))


def compare_pytorch_version(version: str = "v2.0.0", op: str = "eq"):
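    """Return the result of comparing the installed torch version to ``version``
    with the operator named by ``op`` (eq/lt/le/gt/ge)."""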
    assert op in ["eq", "lt", "le", "gt", "ge"]

    from torch import __version__

    return getattr(parse_version(__version__), f"__{op}__")(parse_version(version))