Unverified commit 59aef189 authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #644 from gakada/bits

Add PEFT, quantization, remote code, LLaMA fix
parents 85ebe408 1d8f782f
@@ -44,10 +44,10 @@ To install additional multilingual tokenization and text segmentation packages,
 pip install -e ".[multilingual]"
 ```
-To support loading GPTQ quantized models, install the package with the `auto-gptq` extra:
+To support loading GPTQ quantized models, install the package with the `gptq` extra:
 ```bash
-pip install -e ".[auto-gptq]"
+pip install -e ".[gptq]"
 ```
 ## Basic Usage
@@ -160,17 +160,17 @@ For models loaded with the HuggingFace `transformers` library, any arguments pr
 ```bash
 python main.py \
     --model hf \
-    --model_args pretrained=EleutherAI/gpt-j-6b,peft=nomic-ai/gpt4all-j-lora \
+    --model_args pretrained=EleutherAI/gpt-j-6b,parallelize=True,load_in_4bit=True,peft=nomic-ai/gpt4all-j-lora \
     --tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
     --device cuda:0
 ```
-GPTQ quantized models can be loaded by specifying their file names in `,quantized=NAME` (or `,quantized=True` for default names) in the `model_args` argument:
+[GPTQ](https://github.com/PanQiWei/AutoGPTQ) quantized models can be loaded by specifying their file names in `,gptq=NAME` (or `,gptq=True` for default names) in the `model_args` argument:
 ```bash
 python main.py \
     --model hf \
-    --model_args pretrained=model-name-or-path,quantized=model.safetensors,gptq_use_triton=True \
+    --model_args pretrained=model-name-or-path,gptq=model.safetensors,gptq_use_triton=True \
     --tasks hellaswag
 ```
...
 import torch
 import transformers
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from peft import __version__ as PEFT_VERSION, PeftModel
 import copy
 from collections import defaultdict
 from tqdm import tqdm
+from pathlib import Path
 import torch.nn.functional as F
@@ -58,15 +60,16 @@ class HFLM(LM):
     def __init__(
         self,
-        device="cuda",
-        pretrained="gpt2",
-        revision="main",
-        low_cpu_mem_usage=None,
-        max_length=None,
-        subfolder=None,
-        tokenizer=None,
-        batch_size=1,
+        pretrained: Optional[str] = "gpt2",
+        revision: Optional[str] = "main",
+        subfolder: Optional[str] = None,
+        tokenizer: Optional[str] = None,
+        max_length: Optional[int] = None,
+        device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
+        batch_size: Optional[int] = 1,
+        low_cpu_mem_usage: Optional[bool] = True,
+        trust_remote_code: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
@@ -74,6 +77,14 @@ class HFLM(LM):
         max_memory_per_gpu: Optional[Union[int, str]] = None,
         max_cpu_memory: Optional[Union[int, str]] = None,
         offload_folder: Optional[str] = "./offload",
+        # PEFT and quantization options
+        peft: Optional[str] = None,
+        load_in_8bit: Optional[bool] = False,
+        load_in_4bit: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
+        gptq: Optional[Union[bool, str]] = False,
+        gptq_use_triton: Optional[bool] = False,
     ):
         super().__init__()
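For context on the new 4-bit flags: the constructor forwards `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype` as flat keyword arguments into `model_kwargs` for `from_pretrained` (and passes `load_in_8bit` directly), as shown in the loading code further down. The same settings can equivalently be expressed with `transformers.BitsAndBytesConfig`; the sketch below only illustrates what the flags mean and is not what the harness itself does:

```python
# Illustration only: equivalent bitsandbytes settings expressed through
# BitsAndBytesConfig (requires transformers >= 4.30.0 and bitsandbytes).
# The harness passes the flat load_in_4bit / bnb_4bit_* kwargs instead.
import torch
import transformers

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for matmuls
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6b",
    quantization_config=bnb_config,
    device_map="auto",
)
```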
@@ -117,10 +128,10 @@ class HFLM(LM):
         # TODO: update this to be less of a hack once subfolder is fixed in HF
         revision = revision + ("/" + subfolder if subfolder is not None else "")
-        # get config
         self._config = transformers.AutoConfig.from_pretrained(
             pretrained,
             revision=revision,
+            trust_remote_code=trust_remote_code,
         )
         if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
@@ -133,13 +144,56 @@ class HFLM(LM):
             transformers.AutoModelForSeq2SeqLM,
         ]
-        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
-            pretrained,
-            revision=revision,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            **model_kwargs,
-            torch_dtype=utils.get_dtype(dtype),
-        )
+        if not gptq:
+            if load_in_4bit:
+                assert (
+                    transformers.__version__ >= "4.30.0"
+                ), "load_in_4bit requires transformers >= 4.30.0"
+            if transformers.__version__ >= "4.30.0":
+                model_kwargs["load_in_4bit"] = load_in_4bit
+                if load_in_4bit:
+                    if bnb_4bit_quant_type:
+                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
+                    if bnb_4bit_compute_dtype:
+                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
+                            bnb_4bit_compute_dtype
+                        )
+            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
+                pretrained,
+                revision=revision,
+                torch_dtype=utils.get_dtype(dtype),
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                trust_remote_code=trust_remote_code,
+                load_in_8bit=load_in_8bit,
+                **model_kwargs,
+            )
+        else:
+            try:
+                from auto_gptq import AutoGPTQForCausalLM
+            except ModuleNotFoundError:
+                raise Exception(
+                    "Tried to load auto_gptq, but auto-gptq is not installed ",
+                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
+                )
+            self._model = AutoGPTQForCausalLM.from_quantized(
+                pretrained,
+                model_basename=None if gptq is True else Path(gptq).stem,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                trust_remote_code=trust_remote_code,
+                use_safetensors=True if gptq is True else gptq.endswith(".safetensors"),
+                use_triton=gptq_use_triton,
+                warmup_triton=gptq_use_triton,
+                **model_kwargs,
+            )
+
+        if peft:
+            if load_in_4bit:
+                assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
+            self._model = PeftModel.from_pretrained(
+                self._model, peft, revision=revision
+            )
+
         # forever after, access self._model through self.model property
         self.model.eval()
         self.model.tie_weights()
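A small worked example of how the `gptq` argument is interpreted by the `from_quantized` call above (pure path handling, mirrored from the diff; the file name is only illustrative):

```python
from pathlib import Path

# gptq=True -> default basename and safetensors weights;
# gptq="model.safetensors" -> load that specific file from the pretrained directory.
gptq = "model.safetensors"
model_basename = None if gptq is True else Path(gptq).stem                  # "model"
use_safetensors = True if gptq is True else gptq.endswith(".safetensors")   # True
print(model_basename, use_safetensors)
```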
@@ -150,6 +204,7 @@ class HFLM(LM):
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
             revision=revision,
+            trust_remote_code=trust_remote_code,
         )
         self.vocab_size = self.tokenizer.vocab_size
@@ -361,16 +416,27 @@ class HFLM(LM):
         return logits
 
+    def _encode_pair(self, context, continuation):
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+        whole_enc = self.tok_encode(context + continuation)
+        context_enc = self.tok_encode(context)
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+        return context_enc, continuation_enc
+
     def loglikelihood(self, requests):
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc = [self.eot_token_id]
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
             else:
-                context_enc = self.tok_encode(context)
-                continuation_enc = self.tok_encode(continuation)
+                context_enc, continuation_enc = self._encode_pair(context, continuation)
 
             new_reqs.append(((context, continuation), context_enc, continuation_enc))
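The whitespace handling in `_encode_pair` above moves any trailing spaces from the context onto the continuation, then recovers the continuation tokens by encoding the joined string and slicing off the context's token count, so tokens produced at the boundary are attributed to the continuation. A string-only illustration of the shift (hypothetical example text; the `tok_encode` step is omitted because it needs a real tokenizer):

```python
# Illustrates only the whitespace shift performed by _encode_pair.
context, continuation = "Question: 2 + 2 = ", "4"

n_spaces = len(context) - len(context.rstrip())         # 1 trailing space
if n_spaces > 0:
    continuation = context[-n_spaces:] + continuation   # " 4"
    context = context[:-n_spaces]                       # "Question: 2 + 2 ="

# The pair still re-joins to exactly the original text.
assert context + continuation == "Question: 2 + 2 = " + "4"
print(repr(context), repr(continuation))
```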
@@ -442,7 +508,6 @@ class HFLM(LM):
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
             self.batch_size,
         ):
-
             inps = []
             cont_toks_list = []
             inplens = []
@@ -544,7 +609,6 @@ class HFLM(LM):
             for (cache_key, _, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
             ):
-
                 # Slice to original seq length
                 contlen = len(cont_toks)
                 # take only logits in the continuation
...
@@ -55,7 +55,7 @@ setuptools.setup(
         "promptsource": [
             "promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
         ],
-        "auto-gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
+        "gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
         "anthropic": ["anthropic"],
     },
 )