"vscode:/vscode.git/clone" did not exist on "6c34d6339c040628e895d167cf22f2ab7104f8b3"
Unverified Commit 095d8406 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #519 from gakada/gptq

Add support for loading GPTQ models via AutoGPTQ
parents 8cff2bea c11ad4f2
......@@ -7,7 +7,7 @@ This project provides a unified framework to test generative language models on
Features:
- 200+ tasks implemented. See the [task-table](./docs/task_table.md) for a complete list.
- Support for models loaded via [transformers](https://github.com/huggingface/transformers/), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface.
- Support for models loaded via [transformers](https://github.com/huggingface/transformers/) (including quantization via [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface.
- Support for commercial APIs including [OpenAI](https://openai.com), [goose.ai](https://goose.ai), and [TextSynth](https://textsynth.com/).
- Support for evaluation on adapters (e.g. LoRa) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
- Evaluating with publicly available prompts ensures reproducibility and comparability between papers.
......@@ -29,6 +29,12 @@ To install additional multilingual tokenization and text segmentation packages,
pip install -e ".[multilingual]"
```
To support loading GPTQ quantized models, install the package with the `auto-gptq` extra:
```bash
pip install -e ".[auto-gptq]"
```
## Basic Usage
> **Note**: When reporting results from eval harness, please include the task versions (shown in `results["versions"]`) for reproducibility. This allows bug fixes to tasks while also ensuring that previously reported scores are reproducible. See the [Task Versioning](#task-versioning) section for more info.
......@@ -111,6 +117,14 @@ python main.py \
--device cuda:0
```
GPTQ quantized models can be loaded by specifying their file names in `,quantized=NAME` (or `,quantized=True` for default names) in the `model_args` argument:
```bash
python main.py \
--model hf-causal-experimental \
--model_args pretrained=model-name-or-path,quantized=model.safetensors,gptq_use_triton=True \
--tasks hellaswag
```
We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.
......
......@@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
import transformers
import peft
from pathlib import Path
from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
......@@ -69,6 +70,7 @@ class HuggingFaceAutoLM(BaseLM):
def __init__(
self,
pretrained: str,
quantized: Optional[Union[True, str]] = None,
tokenizer: Optional[str] = None,
subfolder: Optional[str] = None,
revision: Optional[str] = "main",
......@@ -86,6 +88,7 @@ class HuggingFaceAutoLM(BaseLM):
peft: str = None,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
gptq_use_triton: Optional[bool] = False,
):
"""Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
Args:
......@@ -93,6 +96,9 @@ class HuggingFaceAutoLM(BaseLM):
The HuggingFace Hub model ID name or the path to a pre-trained
model to load. This is effectively the `pretrained_model_name_or_path`
argument of `from_pretrained` in the HuggingFace `transformers` API.
quantized (str or True, optional, defaults to None):
File name of a GPTQ quantized model to load. Set to `True` to use the
default name of the quantized model.
add_special_tokens (bool, optional, defaults to True):
Whether to add special tokens to the input sequences. If `None`, the
default value will be set to `True` for seq2seq models (e.g. T5) and
......@@ -139,6 +145,8 @@ class HuggingFaceAutoLM(BaseLM):
https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit
trust_remote_code (bool, optional, defaults to False):
If True, will trust the remote code when loading the model.
gptq_use_triton (bool, optional, defaults to False):
Use Triton for GPTQ inference.
"""
super().__init__()
......@@ -192,10 +200,12 @@ class HuggingFaceAutoLM(BaseLM):
model_kwargs["load_in_8bit"] = load_in_8bit
self.model = self._create_auto_model(
pretrained=pretrained,
quantized=quantized,
trust_remote_code=trust_remote_code,
revision=revision,
subfolder=subfolder,
torch_dtype=_get_dtype(dtype, self._config),
gptq_use_triton=gptq_use_triton,
**model_kwargs,
)
# note: peft_path can be different than pretrained model path
......@@ -224,6 +234,7 @@ class HuggingFaceAutoLM(BaseLM):
self,
*,
pretrained: str,
quantized: Optional[Union[True, str]] = None,
revision: str,
subfolder: str,
device_map: Optional[Union[str, _DeviceMapping]] = None,
......@@ -232,18 +243,32 @@ class HuggingFaceAutoLM(BaseLM):
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
gptq_use_triton: Optional[bool] = False,
) -> transformers.AutoModel:
"""Returns a pre-trained pytorch model from a pre-trained model configuration."""
model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
load_in_8bit=load_in_8bit,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
)
if quantized is None:
model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision + ("/" + subfolder if subfolder is not None else ""),
device_map=device_map,
max_memory=max_memory,
offload_folder=offload_folder,
load_in_8bit=load_in_8bit,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
)
else:
from auto_gptq import AutoGPTQForCausalLM
model = AutoGPTQForCausalLM.from_quantized(
pretrained,
model_basename=None if quantized == True else Path(quantized).stem,
device_map=device_map,
max_memory=max_memory,
trust_remote_code=trust_remote_code,
use_safetensors=True if quantized == True else quantized.endswith('.safetensors'),
use_triton=gptq_use_triton,
warmup_triton=gptq_use_triton,
)
return model
def _create_auto_model_peft(
......
......@@ -44,5 +44,6 @@ setuptools.setup(
"dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
"auto-gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment