"git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "f3445d19fad8920538474fb2097d966a1d48709e"
Unverified commit fea4d11d authored by Baber Abbasi, committed by GitHub

[HF] fix quantization config (#3039)

* Try fixing issue #3026, which is caused by the quantization_config argument introduced in commit 758c5ed8.

The argument is a dict, but for a GPTQ-quantized model this conflicts with the Hugging Face interface, which expects a QuantizationConfigMixin. The current fix removes the quantization_config argument in HFLM._create_model() in lm_eval/models/huggingface.py; further modification is required to restore the functionality provided by the previous commit.

* wrap quantization_config in AutoQuantizationConfig

* handle quantization_config when it is not a dict

* wrap quantization_config in AutoQuantizationConfig if dict

---------
Co-authored-by: shanhx2000 <hs359@duke.edu>
parent 6b3f3f7e
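In short, the fix wraps a dict-valued quantization_config in AutoQuantizationConfig before handing it to the Hugging Face loader. A minimal sketch, assuming a hypothetical GPTQ quantization_config dict (the exact keys vary by checkpoint):

    from transformers.quantizers import AutoQuantizationConfig

    # Hypothetical dict, as read from a GPTQ checkpoint's config.json.
    quantization_config = {"quant_method": "gptq", "bits": 4, "group_size": 128}

    # The HF loading interface expects a QuantizationConfigMixin subclass,
    # not a raw dict, so wrap it before passing it on:
    if isinstance(quantization_config, dict):
        quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

    print(type(quantization_config))  # a GPTQConfig instance

AutoQuantizationConfig.from_dict dispatches on the quant_method key, so the same wrapping applies to other quantization schemes as well.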
lm_eval/models/huggingface.py
@@ -3,7 +3,7 @@ import logging
 import os
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
 
 import jinja2
 import torch
@@ -17,8 +17,6 @@ from accelerate import (
 from accelerate.utils import get_max_memory
 from huggingface_hub import HfApi
 from packaging import version
-from peft import PeftModel
-from peft import __version__ as PEFT_VERSION
 from tqdm import tqdm
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
@@ -40,6 +38,9 @@ from lm_eval.models.utils import (
 )
 
+if TYPE_CHECKING:
+    from transformers.quantizers import AutoQuantizationConfig
+
 eval_logger = logging.getLogger(__name__)
@@ -188,6 +189,13 @@ class HFLM(TemplateLM):
             add_bos_token=add_bos_token,
         )
 
+        if (
+            quantization_config := getattr(self.config, "quantization_config", None)
+        ) is not None and isinstance(quantization_config, dict):
+            from transformers.quantizers import AutoQuantizationConfig
+
+            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
+
         # if we passed `pretrained` as a string, initialize our model now
         if isinstance(pretrained, str):
             self._create_model(
@@ -205,7 +213,7 @@ class HFLM(TemplateLM):
                 autogptq=autogptq,
                 gptqmodel=gptqmodel,
                 gguf_file=gguf_file,
-                quantization_config=getattr(self.config, "quantization_config", None),
+                quantization_config=quantization_config,
                 subfolder=subfolder,
                 **kwargs,
             )
@@ -554,7 +562,7 @@ class HFLM(TemplateLM):
         autogptq: Optional[Union[bool, str]] = False,
         gptqmodel: Optional[bool] = False,
         gguf_file: Optional[str] = None,
-        quantization_config: Optional[Dict[str, Any]] = None,
+        quantization_config: Optional["AutoQuantizationConfig"] = None,
         subfolder: str = "",
         **kwargs,
     ) -> None:
@@ -649,6 +657,9 @@ class HFLM(TemplateLM):
         )
 
         if peft:
+            from peft import PeftModel
+            from peft import __version__ as PEFT_VERSION
+
             if model_kwargs.get("load_in_4bit", None):
                 if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
                     raise AssertionError("load_in_4bit requires peft >= 0.4.0")
...
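The diff uses two related import patterns to keep optional dependencies out of module import time: a TYPE_CHECKING guard for the annotation-only import, and a deferred import inside the peft branch. A minimal sketch of both (the function name below is hypothetical):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        # Evaluated only by static type checkers, so the annotation below
        # does not trigger a runtime import of transformers' quantizer module.
        from transformers.quantizers import AutoQuantizationConfig


    def create_model_sketch(
        quantization_config: Optional["AutoQuantizationConfig"] = None,
        peft: Optional[str] = None,
    ) -> None:
        if peft:
            # Deferred import: peft remains an optional dependency and is
            # only required when a PEFT adapter is actually requested.
            from peft import PeftModel  # noqa: F401

This is why the module-level peft imports could be removed: an ImportError now surfaces only for users who actually pass a peft argument.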