Unverified Commit 4b701e22 authored by gakada, committed by GitHub

Add load_in_4bit and fix peft loading (#556)

parent 01e0626f
@@ -3,6 +3,7 @@ import torch
 import torch.nn.functional as F
 import transformers
 import peft
+from peft import __version__ as PEFT_VERSION
 from pathlib import Path
 from typing import List, Mapping, NewType, Optional, Tuple, Union
 from tqdm import tqdm
@@ -87,6 +88,7 @@ class HuggingFaceAutoLM(BaseLM):
         device: Optional[Union[int, str]] = "cuda",
         peft: str = None,
         load_in_8bit: Optional[bool] = False,
+        load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
     ):
@@ -142,7 +144,10 @@ class HuggingFaceAutoLM(BaseLM):
                 `adapter_model.bin`. Compatible with [PEFT](https://github.com/huggingface/peft)
             load_in_8bit (bool, optional, defaults to False):
                 If True, will convert the loaded model into mixed-8bit quantized model. See:
-                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit
+                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#load-a-large-model-in-8bit
+            load_in_4bit (bool, optional, defaults to False):
+                If True, will convert the loaded model into mixed-4bit quantized model. See:
+                https://huggingface.co/docs/transformers/main/en/main_classes/quantization#load-a-large-model-in-4bit
             trust_remote_code (bool, optional, defaults to False):
                 If True, will trust the remote code when loading the model.
             gptq_use_triton (bool, optional, defaults to False):
@@ -197,7 +202,6 @@ class HuggingFaceAutoLM(BaseLM):
             max_cpu_memory,
             offload_folder,
         )
-        model_kwargs["load_in_8bit"] = load_in_8bit
         self.model = self._create_auto_model(
             pretrained=pretrained,
             quantized=quantized,
@@ -206,6 +210,8 @@ class HuggingFaceAutoLM(BaseLM):
             subfolder=subfolder,
             torch_dtype=_get_dtype(dtype, self._config),
             gptq_use_triton=gptq_use_triton,
+            load_in_8bit=load_in_8bit,
+            load_in_4bit=load_in_4bit,
             **model_kwargs,
         )
         # note: peft_path can be different than pretrained model path
@@ -215,8 +221,7 @@ class HuggingFaceAutoLM(BaseLM):
             peft=peft,
             revision=revision,
             subfolder=subfolder,
-            torch_dtype=_get_dtype(dtype, self._config),
-            **model_kwargs,
+            load_in_4bit=load_in_4bit,
         )
         self.model.eval()
         torch.set_grad_enabled(False)
@@ -241,12 +246,18 @@ class HuggingFaceAutoLM(BaseLM):
         max_memory: Optional[dict] = None,
         offload_folder: Optional[str] = None,
         load_in_8bit: Optional[bool] = False,
+        load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
     ) -> transformers.AutoModel:
         """Returns a pre-trained pytorch model from a pre-trained model configuration."""
         if not quantized:
+            if load_in_4bit:
+                assert transformers.__version__ >= "4.30.0", "load_in_4bit requires transformers >= 4.30.0"
+            model_kwargs = {}
+            if transformers.__version__ >= "4.30.0":
+                model_kwargs["load_in_4bit"] = load_in_4bit
             model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision + ("/" + subfolder if subfolder is not None else ""),
@@ -256,6 +267,7 @@ class HuggingFaceAutoLM(BaseLM):
                 load_in_8bit=load_in_8bit,
                 trust_remote_code=trust_remote_code,
                 torch_dtype=torch_dtype,
+                **model_kwargs,
             )
         else:
             from auto_gptq import AutoGPTQForCausalLM
@@ -278,23 +290,14 @@ class HuggingFaceAutoLM(BaseLM):
         peft: str,
         revision: str,
         subfolder: str,
-        device_map: Optional[Union[str, _DeviceMapping]] = None,
-        max_memory: Optional[dict] = None,
-        offload_folder: Optional[str] = None,
-        load_in_8bit: Optional[bool] = False,
-        trust_remote_code: Optional[bool] = False,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        load_in_4bit: Optional[bool] = False,
     ):
+        if load_in_4bit:
+            assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
         model = self.AUTO_PEFT_CLASS.from_pretrained(
             model,
             peft,
             revision=revision + ("/" + subfolder if subfolder is not None else ""),
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            load_in_8bit=load_in_8bit,
-            trust_remote_code=trust_remote_code,
-            torch_dtype=torch_dtype,
         )
         return model
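
For context, a minimal sketch of what the new load_in_4bit flag means on the transformers side, assuming transformers >= 4.30.0, bitsandbytes, and accelerate are installed; the model ID below is purely illustrative and not part of this change:

import transformers

load_in_4bit = True
model_kwargs = {}
if transformers.__version__ >= "4.30.0":
    # only forward the kwarg when the installed transformers version understands it,
    # mirroring the guard added in _create_auto_model above
    model_kwargs["load_in_4bit"] = load_in_4bit

model = transformers.AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model ID
    device_map="auto",    # requires accelerate; places the quantized weights on available devices
    **model_kwargs,
)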
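
Likewise, a rough sketch of the PEFT side of the fix, assuming peft >= 0.4.0 (the minimum the new assert enforces for 4-bit base models) and that AUTO_PEFT_CLASS resolves to peft.PeftModel; the model and adapter IDs are purely illustrative:

import transformers
from peft import PeftModel, __version__ as PEFT_VERSION

load_in_4bit = True
if load_in_4bit:
    # same guard as the one added to _create_auto_model_peft above
    assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"

# base model, quantized to 4 bits as in the sketch above
base = transformers.AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model ID
    load_in_4bit=load_in_4bit,
)

# attach the adapter on top of the already-quantized base model; no device_map,
# offload, or dtype kwargs are passed here, matching the simplified signature
model = PeftModel.from_pretrained(base, "username/opt-125m-lora")  # illustrative adapter ID
model.eval()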