Commit 436edcc9 authored by codehound42's avatar codehound42
Browse files

Add `inject_fused_attention` parameter for `auto-gptq`

parent fd1c7196
......@@ -94,6 +94,7 @@ class HuggingFaceAutoLM(BaseLM):
load_in_4bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
gptq_use_triton: Optional[bool] = False,
inject_fused_attention: Optional[bool] = True,
bnb_4bit_quant_type: Optional[str] = None,
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
bnb_4bit_use_double_quant: Optional[bool] = False,
......@@ -160,6 +161,8 @@ class HuggingFaceAutoLM(BaseLM):
If True, will trust the remote code when loading the model.
gptq_use_triton (bool, optional, defaults to False):
Use Triton for GPTQ inference.
inject_fused_attention (bool, optional, defaults to True):
Inject fused attention.
bnb_4bit_quant_type (str, optional, defaults to None):
The quantization type to use for BnB 4bit quantization. See:
https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
......@@ -233,6 +236,7 @@ class HuggingFaceAutoLM(BaseLM):
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
gptq_use_triton=gptq_use_triton,
inject_fused_attention=inject_fused_attention,
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
bnb_4bit_quant_type=bnb_4bit_quant_type,
......@@ -280,6 +284,7 @@ class HuggingFaceAutoLM(BaseLM):
trust_remote_code: Optional[bool] = False,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
gptq_use_triton: Optional[bool] = False,
inject_fused_attention: Optional[bool] = True,
bnb_4bit_quant_type: Optional[str] = None,
bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
bnb_4bit_use_double_quant: Optional[bool] = False,
......@@ -321,6 +326,7 @@ class HuggingFaceAutoLM(BaseLM):
use_safetensors=True if quantized == True else quantized.endswith('.safetensors'),
use_triton=gptq_use_triton,
warmup_triton=gptq_use_triton,
inject_fused_attention=inject_fused_attention,
)
return model
......@@ -775,4 +781,4 @@ def stop_sequences_criteria(
for sequence in stop_sequences
],
]
)
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment