Commit 436edcc9 authored by codehound42

Add `inject_fused_attention` parameter for `auto-gptq`

parent fd1c7196
@@ -94,6 +94,7 @@ class HuggingFaceAutoLM(BaseLM):
         load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
+        inject_fused_attention: Optional[bool] = True,
         bnb_4bit_quant_type: Optional[str] = None,
         bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
         bnb_4bit_use_double_quant: Optional[bool] = False,
@@ -160,6 +161,8 @@ class HuggingFaceAutoLM(BaseLM):
             If True, will trust the remote code when loading the model.
         gptq_use_triton (bool, optional, defaults to False):
             Use Triton for GPTQ inference.
+        inject_fused_attention (bool, optional, defaults to True):
+            Inject fused attention.
         bnb_4bit_quant_type (str, optional, defaults to None):
             The quantization type to use for BnB 4bit quantization. See:
             https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
@@ -233,6 +236,7 @@ class HuggingFaceAutoLM(BaseLM):
             subfolder=subfolder,
             torch_dtype=_get_dtype(dtype, self._config),
             gptq_use_triton=gptq_use_triton,
+            inject_fused_attention=inject_fused_attention,
             load_in_8bit=load_in_8bit,
             load_in_4bit=load_in_4bit,
             bnb_4bit_quant_type=bnb_4bit_quant_type,
@@ -280,6 +284,7 @@ class HuggingFaceAutoLM(BaseLM):
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
+        inject_fused_attention: Optional[bool] = True,
         bnb_4bit_quant_type: Optional[str] = None,
         bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
         bnb_4bit_use_double_quant: Optional[bool] = False,
@@ -321,6 +326,7 @@ class HuggingFaceAutoLM(BaseLM):
                 use_safetensors=True if quantized == True else quantized.endswith('.safetensors'),
                 use_triton=gptq_use_triton,
                 warmup_triton=gptq_use_triton,
+                inject_fused_attention=inject_fused_attention,
             )
         return model
...
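For context (not part of the commit): `_create_auto_model` forwards the new `inject_fused_attention` argument to AutoGPTQ's `from_quantized`, so disabling it from user code would look roughly like the sketch below. The checkpoint name is a placeholder, and the surrounding values mirror the defaults visible in the diff.

```python
# Minimal sketch, assuming the auto-gptq from_quantized API shown in the diff.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized(
    "some-org/some-model-GPTQ",     # placeholder GPTQ checkpoint
    use_safetensors=True,
    use_triton=False,               # mirrors gptq_use_triton=False
    warmup_triton=False,
    inject_fused_attention=False,   # the new pass-through; set False for
                                    # architectures without fused-attention
                                    # support in auto-gptq
)
```

Defaulting the harness parameter to True keeps the previous behaviour (auto-gptq enables fused attention by default), while exposing it lets users turn fused attention off for models whose architecture the fused kernels do not support.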