Commit 990bc542 authored by ynot

Add 4-bit-related args

parent f862a118
@@ -91,6 +91,8 @@ class HuggingFaceAutoLM(BaseLM):
         load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
         Args:
@@ -152,6 +154,13 @@ class HuggingFaceAutoLM(BaseLM):
                 If True, will trust the remote code when loading the model.
             gptq_use_triton (bool, optional, defaults to False):
                 Use Triton for GPTQ inference.
+            bnb_4bit_quant_type (str, optional, defaults to None):
+                The quantization type to use for BnB 4bit quantization. See:
+                https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
+            bnb_4bit_compute_dtype (Union[str, torch.dtype], optional, defaults to None):
+                The compute dtype to use for BnB 4bit quantization. See:
+                https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L74
         """
         super().__init__()
@@ -215,6 +224,8 @@ class HuggingFaceAutoLM(BaseLM):
             gptq_use_triton=gptq_use_triton,
             load_in_8bit=load_in_8bit,
             load_in_4bit=load_in_4bit,
+            bnb_4bit_quant_type=bnb_4bit_quant_type,
+            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
             **model_kwargs,
         )
         # note: peft_path can be different than pretrained model path
@@ -253,6 +264,8 @@ class HuggingFaceAutoLM(BaseLM):
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
     ) -> transformers.AutoModel:
         """Returns a pre-trained pytorch model from a pre-trained model configuration."""
         if not quantized:
@@ -261,6 +274,12 @@ class HuggingFaceAutoLM(BaseLM):
             model_kwargs = {}
             if transformers.__version__ >= "4.30.0":
                 model_kwargs["load_in_4bit"] = load_in_4bit
+                if load_in_4bit:
+                    if bnb_4bit_quant_type is not None:
+                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
+                    if bnb_4bit_compute_dtype is not None:
+                        # Accept either a string such as "bfloat16" or a torch.dtype.
+                        model_kwargs["bnb_4bit_compute_dtype"] = getattr(torch, bnb_4bit_compute_dtype) if isinstance(bnb_4bit_compute_dtype, str) else bnb_4bit_compute_dtype
             model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision + ("/" + subfolder if subfolder is not None else ""),
...
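
For reference, these two kwargs map onto `transformers`' `BitsAndBytesConfig`: from version 4.30 onward, `from_pretrained` folds loose `bnb_4bit_*` kwargs into such a config internally. Below is a minimal, equivalent sketch that builds the config explicitly; the model id is only illustrative, and it assumes `bitsandbytes` is installed and a CUDA device is available.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# NF4 4-bit weight quantization with bfloat16 compute for the matmuls,
# i.e. what bnb_4bit_quant_type / bnb_4bit_compute_dtype configure above.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # transformers defaults to "fp4"
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model id
    device_map="auto",
    quantization_config=quant_config,
)
```

Assuming the harness's `--model_args` parsing forwards these constructor kwargs (not shown in this diff), the same setup would look like `--model_args pretrained=...,load_in_4bit=True,bnb_4bit_quant_type=nf4,bnb_4bit_compute_dtype=bfloat16` on the command line.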
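
One caveat in the surrounding context line: `transformers.__version__ >= "4.30.0"` compares version strings lexicographically, so for example `"4.9.0" >= "4.30.0"` evaluates to `True` even though 4.9 predates 4.30. A more robust check, sketched here with the `packaging` library:

```python
import transformers
from packaging import version

# Parse before comparing; plain string comparison misorders
# multi-digit components ("4.9.0" > "4.30.0" as strings).
if version.parse(transformers.__version__) >= version.parse("4.30.0"):
    ...
```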