Commit 990bc542 authored by ynot

Add 4bit-related args

parent f862a118
@@ -91,6 +91,8 @@ class HuggingFaceAutoLM(BaseLM):
        load_in_4bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        gptq_use_triton: Optional[bool] = False,
        bnb_4bit_quant_type: Optional[str] = None,
        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
    ):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
Args:
......@@ -152,6 +154,13 @@ class HuggingFaceAutoLM(BaseLM):
                If True, will trust the remote code when loading the model.
            gptq_use_triton (bool, optional, defaults to False):
                Use Triton for GPTQ inference.
            bnb_4bit_quant_type (str, optional, defaults to None):
                The quantization type to use for BnB 4bit quantization (e.g. "nf4" or "fp4"). See:
                https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
            bnb_4bit_compute_dtype (Union[str, torch.dtype], optional, defaults to None):
                The compute dtype to use for BnB 4bit quantization. See:
                https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L74
        """
        super().__init__()
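
For orientation, a hypothetical call exercising the two new arguments is sketched below. `AutoCausalLM` stands in for a concrete `HuggingFaceAutoLM` subclass; the model id and the dtype choices are illustrative, not part of this commit:

# Illustrative usage sketch (not part of the diff): 4-bit NF4 weights with
# bfloat16 compute; requires transformers >= 4.30.0 and bitsandbytes.
lm = AutoCausalLM(
    pretrained="facebook/opt-1.3b",     # illustrative model id
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # example value; "fp4" also valid
    bnb_4bit_compute_dtype="bfloat16",  # a torch.dtype works too
)
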
@@ -215,6 +224,8 @@ class HuggingFaceAutoLM(BaseLM):
            gptq_use_triton=gptq_use_triton,
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit,
            bnb_4bit_quant_type=bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
            **model_kwargs,
        )
        # note: peft_path can be different from the pretrained model path
@@ -253,6 +264,8 @@ class HuggingFaceAutoLM(BaseLM):
        trust_remote_code: Optional[bool] = False,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq_use_triton: Optional[bool] = False,
        bnb_4bit_quant_type: Optional[str] = None,
        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
    ) -> transformers.AutoModel:
        """Returns a pre-trained pytorch model from a pre-trained model configuration."""
        if not quantized:
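
The hunk that follows forwards these values into `model_kwargs` for `from_pretrained`; since `bnb_4bit_compute_dtype` may arrive as a string, it has to be coerced to a `torch.dtype` first. A minimal standalone sketch of that coercion (the helper name here is hypothetical):

import torch

def _to_torch_dtype(dtype):
    # Hypothetical helper mirroring the inline coercion in the hunk below:
    # "bfloat16" -> torch.bfloat16; torch.dtype values pass through unchanged.
    return getattr(torch, dtype) if isinstance(dtype, str) else dtype

assert _to_torch_dtype("bfloat16") is torch.bfloat16
assert _to_torch_dtype(torch.float16) is torch.float16
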
@@ -261,6 +274,9 @@ class HuggingFaceAutoLM(BaseLM):
            model_kwargs = {}
            if transformers.__version__ >= "4.30.0":
                model_kwargs["load_in_4bit"] = load_in_4bit
                if load_in_4bit:
                    if bnb_4bit_quant_type:
                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
                    if bnb_4bit_compute_dtype:
                        # Coerce "bfloat16"-style strings to a torch.dtype; skip when None.
                        model_kwargs["bnb_4bit_compute_dtype"] = (
                            getattr(torch, bnb_4bit_compute_dtype)
                            if isinstance(bnb_4bit_compute_dtype, str)
                            else bnb_4bit_compute_dtype
                        )
            model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision + ("/" + subfolder if subfolder is not None else ""),
......
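
For reference, transformers (>= 4.30.0) gathers the bnb_* keyword arguments forwarded above into a `BitsAndBytesConfig`. A minimal sketch of the equivalent direct transformers call, with an illustrative model id:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # transformers' default is "fp4"
    bnb_4bit_compute_dtype=torch.bfloat16,  # default is torch.float32
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",  # illustrative model id
    quantization_config=quant_config,
    device_map="auto",    # requires accelerate
)
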