Unverified commit 4b701e22 authored by gakada, committed by GitHub

Add load_in_4bit and fix peft loading (#556)

parent 01e0626f
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 import transformers
 import peft
+from peft import __version__ as PEFT_VERSION
 from pathlib import Path
 from typing import List, Mapping, NewType, Optional, Tuple, Union
 from tqdm import tqdm
@@ -87,6 +88,7 @@ class HuggingFaceAutoLM(BaseLM):
         device: Optional[Union[int, str]] = "cuda",
         peft: str = None,
         load_in_8bit: Optional[bool] = False,
+        load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
     ):
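For orientation, a minimal usage sketch of the extended constructor; the `AutoCausalLM` subclass and its import path are assumptions about how this class is exposed by the harness, and the repo names are illustrative examples:

```python
# A minimal sketch, assuming AutoCausalLM subclasses HuggingFaceAutoLM and
# is importable from this path; checkpoint names are examples only.
from lm_eval.models.huggingface import AutoCausalLM

lm = AutoCausalLM(
    pretrained="huggyllama/llama-7b",  # base checkpoint (example)
    peft="tloen/alpaca-lora-7b",       # adapter weights (example)
    load_in_4bit=True,                 # the new flag added by this commit
)
```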
@@ -142,7 +144,10 @@ class HuggingFaceAutoLM(BaseLM):
                 `adapter_model.bin`. Compatible with [PEFT](https://github.com/huggingface/peft)
             load_in_8bit (bool, optional, defaults to False):
                 If True, will convert the loaded model into mixed-8bit quantized model. See:
-                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit
+                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#load-a-large-model-in-8bit
+            load_in_4bit (bool, optional, defaults to False):
+                If True, will convert the loaded model into mixed-4bit quantized model. See:
+                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#load-a-large-model-in-4bit
             trust_remote_code (bool, optional, defaults to False):
                 If True, will trust the remote code when loading the model.
             gptq_use_triton (bool, optional, defaults to False):
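For what the two flags do under the hood, a sketch of the equivalent direct transformers call (independent of this class); `load_in_4bit` routes through the bitsandbytes integration added in transformers 4.30, and the model name is an illustrative example:

```python
# Direct equivalent in transformers >= 4.30: 4-bit (NF4) weight
# quantization at load time via bitsandbytes.
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    load_in_4bit=True,   # pass load_in_8bit=True instead for mixed-8bit
    device_map="auto",   # let accelerate place the quantized weights
)
```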
@@ -197,7 +202,6 @@ class HuggingFaceAutoLM(BaseLM):
             max_cpu_memory,
             offload_folder,
         )
-        model_kwargs["load_in_8bit"] = load_in_8bit
         self.model = self._create_auto_model(
             pretrained=pretrained,
             quantized=quantized,
@@ -206,6 +210,8 @@ class HuggingFaceAutoLM(BaseLM):
             subfolder=subfolder,
             torch_dtype=_get_dtype(dtype, self._config),
             gptq_use_triton=gptq_use_triton,
+            load_in_8bit=load_in_8bit,
+            load_in_4bit=load_in_4bit,
             **model_kwargs,
         )
         # note: peft_path can be different than pretrained model path
@@ -215,8 +221,7 @@ class HuggingFaceAutoLM(BaseLM):
                 peft=peft,
                 revision=revision,
                 subfolder=subfolder,
-                torch_dtype=_get_dtype(dtype, self._config),
-                **model_kwargs,
+                load_in_4bit=load_in_4bit,
             )
         self.model.eval()
         torch.set_grad_enabled(False)
@@ -241,12 +246,18 @@ class HuggingFaceAutoLM(BaseLM):
         max_memory: Optional[dict] = None,
         offload_folder: Optional[str] = None,
         load_in_8bit: Optional[bool] = False,
+        load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
     ) -> transformers.AutoModel:
         """Returns a pre-trained pytorch model from a pre-trained model configuration."""
         if not quantized:
+            if load_in_4bit:
+                assert transformers.__version__ >= "4.30.0", "load_in_4bit requires transformers >= 4.30.0"
+            model_kwargs = {}
+            if transformers.__version__ >= "4.30.0":
+                model_kwargs["load_in_4bit"] = load_in_4bit
             model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision + ("/" + subfolder if subfolder is not None else ""),
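One caveat in the gate above: `transformers.__version__ >= "4.30.0"` is a lexicographic string comparison, so for example `"4.4.0" >= "4.30.0"` evaluates to True even though 4.4 predates 4.30. A sketch of a sturdier check using the `packaging` library (already a transformers dependency); this is a suggested alternative, not what the commit itself does:

```python
# Semantic version comparison instead of raw string comparison.
from packaging import version
import transformers

def supports_load_in_4bit() -> bool:
    # True only on transformers releases that accept load_in_4bit (>= 4.30.0).
    return version.parse(transformers.__version__) >= version.parse("4.30.0")
```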
@@ -256,6 +267,7 @@ class HuggingFaceAutoLM(BaseLM):
                 load_in_8bit=load_in_8bit,
                 trust_remote_code=trust_remote_code,
                 torch_dtype=torch_dtype,
+                **model_kwargs,
             )
         else:
             from auto_gptq import AutoGPTQForCausalLM
@@ -278,23 +290,14 @@ class HuggingFaceAutoLM(BaseLM):
         peft: str,
         revision: str,
         subfolder: str,
-        device_map: Optional[Union[str, _DeviceMapping]] = None,
-        max_memory: Optional[dict] = None,
-        offload_folder: Optional[str] = None,
-        load_in_8bit: Optional[bool] = False,
-        trust_remote_code: Optional[bool] = False,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        load_in_4bit: Optional[bool] = False,
     ):
+        if load_in_4bit:
+            assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
         model = self.AUTO_PEFT_CLASS.from_pretrained(
             model,
             peft,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            load_in_8bit=load_in_8bit,
-            trust_remote_code=trust_remote_code,
-            torch_dtype=torch_dtype,
         )
         return model
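The deleted keywords reflect that quantization and device placement belong to the base model, which is already loaded by the time the adapter is attached, and dropping them from the adapter call is presumably the "fix peft loading" part of this commit. A standalone sketch of the equivalent flow (repo names are illustrative examples):

```python
# Load the quantized base model first, then attach the adapter;
# PeftModel.from_pretrained inherits the base model's dtype and
# placement, so the kwargs removed above are not passed again here.
import transformers
from peft import PeftModel

base = transformers.AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", load_in_4bit=True, device_map="auto"
)
model = PeftModel.from_pretrained(base, "tloen/alpaca-lora-7b")
```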