Commit e6960b9a authored by svenhendrikx

Merge branch 'master' into instantiate-model-from-Automodel

parents ddc634f2 b281b092
-* @jon-tow @StellaAthena @haileyschoelkopf @lintangsutawika
+* @haileyschoelkopf @lintangsutawika
# Language Model Evaluation Harness
## Notice to Users
(as of 6/15/23)
We have a revamp of the Evaluation Harness library internals staged on the [big-refactor](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) branch! It is far along, but before we move the `master` branch of the repository over to this new design with a new version release, we'd like to ensure that it has been tested by outside users and that there are no glaring bugs.
We’d like your help to test it out! You can help by:
1. Trying out your current workloads on the `big-refactor` branch and seeing if anything breaks or is counterintuitive.
2. Porting tasks supported in the previous version of the harness to the new YAML configuration format. Please check out our [task implementation guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md) for more information.
If you choose to port a task that is not yet marked complete on [our checklist](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/tasks/README.md), you can contribute it by opening a PR with [Refactor] in the title that includes:
- A shell command to run the task on the `master` branch, and the resulting score
- A shell command to run the task on your PR branch against `big-refactor`, and the resulting score, to show that the two implementations produce equal results.
Lastly, while we carry out this switch to the new version over the next week, we will not be accepting new feature requests beyond those already open against the `master` branch; we will, however, continue to accept bugfixes to `master` and PRs to `big-refactor`. Feel free to reach out in the #lm-thunderdome channel of the EleutherAI Discord for more information.
## Overview

This project provides a unified framework to test generative language models on a large number of different evaluation tasks.
......
@@ -119,6 +119,12 @@ class LM(abc.ABC):
 class BaseLM(LM):
+    def __init__(self):
+        super().__init__()
+        self.batch_schedule = 1
+        self.batch_sizes = {}
+        self.max_batch_size = 512
+
     @property
     @abstractmethod
     def eot_token_id(self):
@@ -167,6 +173,26 @@ class BaseLM(LM):
         """
         pass

+    def _detect_batch_size(self, requests=None, pos=0):
+        if requests:
+            _, context_enc, continuation_enc = requests[pos]
+            max_length = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+        else:
+            max_length = self.max_length
+
+        # if OOM, then halves batch_size and tries again
+        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
+        def forward_batch(batch_size):
+            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
+            for _ in range(5):
+                _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
+            return batch_size
+
+        batch_size = forward_batch()
+        utils.clear_torch_cache()
+        return batch_size
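For context, `find_executable_batch_size` (from `accelerate`) reruns the decorated function with a halved batch size whenever it hits a CUDA out-of-memory error, so the probe above settles on the largest size that fits. A rough, illustrative stand-in for that retry loop (not the library's actual implementation):

```python
import torch

def probe_largest_batch_size(fn, starting_batch_size=512):
    """Illustrative only: call fn(batch_size), halving on CUDA OOM until it fits."""
    batch_size = starting_batch_size
    while batch_size >= 1:
        try:
            return fn(batch_size)
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()
            batch_size //= 2
    raise RuntimeError("even batch_size=1 ran out of memory")
```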
     # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
     # TODO: enforce this somehow
@@ -202,19 +228,7 @@ class BaseLM(LM):
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
-            @find_executable_batch_size(
-                starting_batch_size=512
-            )  # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones(
-                    (batch_size, self.max_length), device=self.device
-                ).long()
-                for _ in range(5):
-                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                return batch_size
-
-            batch_size = forward_batch()
+            batch_size = self._detect_batch_size()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
@@ -267,34 +281,24 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)

+        reordered_requests = re_ord.get_reordered()
+        n_reordered_requests = len(reordered_requests)
+
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        if len(re_ord.get_reordered()) > 0:
-            _, context_enc, continuation_enc = re_ord.get_reordered()[0]
-            max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
-            if (self.batch_size == 'auto'):
-                if override_bs is None:
-                    print('Passed argument batch_size = auto. Detecting largest batch size')
-                    @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
-                    def forward_batch(batch_size):
-                        test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                        for _ in range(5):
-                            out = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                        return batch_size
-                    batch_size = forward_batch()
-                    print(f"Determined largest batch size: {batch_size}")
-                    adaptive_batch_size = batch_size
-                else:
-                    adaptive_batch_size = override_bs
-        else:
-            adaptive_batch_size = 0 if override_bs is None else override_bs
+        def _batch_scheduler(pos):
+            sched = pos // int(n_reordered_requests / self.batch_schedule)
+            if sched in self.batch_sizes:
+                return self.batch_sizes[sched]
+            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
+            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+            return self.batch_sizes[sched]

         for chunk in utils.chunks(
-            tqdm(re_ord.get_reordered(), disable=disable_tqdm),
-            self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
+            tqdm(reordered_requests, disable=disable_tqdm),
+            n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
+            fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
         ):
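The scheduler above recomputes the batch size once per slice of the length-sorted requests: with `batch_schedule = S`, position `pos` falls into bucket `pos // int(n / S)`, so detection runs at most about `S` times per call. A toy illustration with hypothetical numbers:

```python
n_reordered_requests = 1000
batch_schedule = 4  # e.g. from --batch_size auto:4

def bucket(pos):
    return pos // int(n_reordered_requests / batch_schedule)

# Requests are sorted longest-first, so early buckets contain the longest
# sequences and typically end up with the smallest detected batch sizes.
print([bucket(p) for p in (0, 249, 250, 999)])  # [0, 0, 1, 3]
```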
             inps = []
             cont_toks_list = []
@@ -864,6 +868,10 @@ class CachingLM:
         lm.set_cache_hook(self.get_cache_hook())

     def __getattr__(self, attr):
+        lm_attr = getattr(self.lm, attr)
+        if not callable(lm_attr):
+            return lm_attr
+
         def fn(requests):
             res = []
             remaining_reqs = []
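This guard turns `CachingLM` into a transparent proxy for plain data attributes (for example the new `batch_sizes` dict read by the evaluator), while callables are still wrapped with the caching logic. A minimal sketch of the same pattern, with hypothetical names:

```python
class CachingProxy:
    """Illustrative only: forward data attributes, wrap callables."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getattr__(self, attr):
        target = getattr(self.wrapped, attr)
        if not callable(target):
            return target  # e.g. proxy.batch_sizes just works

        def cached(*args, **kwargs):
            # cache lookup / store would go here
            return target(*args, **kwargs)

        return cached
```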
......
---
dataset_info:
features:
- name: question_id
dtype: string
- name: question_source
dtype: string
- name: question
dtype: string
- name: answer
struct:
- name: aliases
sequence: string
- name: value
dtype: string
- name: search_results
sequence:
- name: description
dtype: string
- name: filename
dtype: string
- name: rank
dtype: int32
- name: title
dtype: string
- name: url
dtype: string
- name: search_context
dtype: string
config_name: triviaqa
splits:
- name: train
num_bytes: 1270894387
num_examples: 87622
- name: validation
num_bytes: 163755044
num_examples: 11313
download_size: 632549060
dataset_size: 1434649431
---
{"triviaqa": {"description": "TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence\ntriples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts\nand independently gathered evidence documents, six per question on average, that provide\nhigh quality distant supervision for answering the questions.\n", "citation": "@InProceedings{JoshiTriviaQA2017,\n author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},\n title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},\n booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},\n month = {July},\n year = {2017},\n address = {Vancouver, Canada},\n publisher = {Association for Computational Linguistics},\n}\n", "homepage": "https://nlp.cs.washington.edu/triviaqa/", "license": "Apache License 2.0", "features": {"question_id": {"dtype": "string", "id": null, "_type": "Value"}, "question_source": {"dtype": "string", "id": null, "_type": "Value"}, "question": {"dtype": "string", "id": null, "_type": "Value"}, "answer": {"aliases": {"feature": {"dtype": "string", "id": null, "_type": "Value"}, "length": -1, "id": null, "_type": "Sequence"}, "value": {"dtype": "string", "id": null, "_type": "Value"}}, "search_results": {"feature": {"description": {"dtype": "string", "id": null, "_type": "Value"}, "filename": {"dtype": "string", "id": null, "_type": "Value"}, "rank": {"dtype": "int32", "id": null, "_type": "Value"}, "title": {"dtype": "string", "id": null, "_type": "Value"}, "url": {"dtype": "string", "id": null, "_type": "Value"}, "search_context": {"dtype": "string", "id": null, "_type": "Value"}}, "length": -1, "id": null, "_type": "Sequence"}}, "post_processed": null, "supervised_keys": null, "task_templates": null, "builder_name": "triviaqa", "config_name": "triviaqa", "version": {"version_str": "0.0.1", "description": null, "major": 0, "minor": 0, "patch": 1}, "splits": {"train": {"name": "train", "num_bytes": 1271393601, "num_examples": 87622, "dataset_name": "triviaqa"}, "validation": {"name": "validation", "num_bytes": 163819509, "num_examples": 11313, "dataset_name": "triviaqa"}}, "download_checksums": {"http://eaidata.bmk.sh/data/triviaqa-unfiltered.tar.gz": {"num_bytes": 546481381, "checksum": "adc19b42769062d241a8fbe834c56e58598d9322eb6c614e9f33a68a2cf5523e"}}, "download_size": 546481381, "post_processing_size": null, "dataset_size": 1435213110, "size_in_bytes": 1981694491}}
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Custom TriviaQA because HF version sanitizes the dataset differently.
# https://github.com/huggingface/datasets/blob/9977ade72191ff0b6907ec63935448c6269a91a1/datasets/trivia_qa/trivia_qa.py#L285
"""TriviaQA (Unfiltered Raw) dataset."""
import json
import os
import datasets
_CITATION = """\
@InProceedings{JoshiTriviaQA2017,
author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
title = {TriviaQA: A Large Scale Distantly Supervised Challenge Dataset for Reading Comprehension},
booktitle = {Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics},
month = {July},
year = {2017},
address = {Vancouver, Canada},
publisher = {Association for Computational Linguistics},
}
"""
_DESCRIPTION = """\
TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence
triples. TriviaQA includes 95K question-answer pairs authored by trivia enthusiasts
and independently gathered evidence documents, six per question on average, that provide
high quality distant supervision for answering the questions.
"""
_HOMEPAGE = "https://nlp.cs.washington.edu/triviaqa/"
_LICENSE = "Apache License 2.0"
_URLS = "https://nlp.cs.washington.edu/triviaqa/data/triviaqa-unfiltered.tar.gz"
class Triviaqa(datasets.GeneratorBasedBuilder):
"""TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence triples"""
VERSION = datasets.Version("0.0.2")
BUILDER_CONFIGS = [
datasets.BuilderConfig(
name="triviaqa", version=VERSION, description="The TriviaQA dataset"
),
]
def _info(self):
features = datasets.Features(
{
"question_id": datasets.Value("string"),
"question_source": datasets.Value("string"),
"question": datasets.Value("string"),
"answer": {
"aliases": datasets.features.Sequence(
datasets.Value("string"),
),
"value": datasets.Value("string"),
},
"search_results": datasets.features.Sequence(
{
"description": datasets.Value("string"),
"filename": datasets.Value("string"),
"rank": datasets.Value("int32"),
"title": datasets.Value("string"),
"url": datasets.Value("string"),
"search_context": datasets.Value("string"),
}
),
}
)
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=features,
homepage=_HOMEPAGE,
license=_LICENSE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
urls = _URLS
data_dir = dl_manager.download_and_extract(urls)
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(
data_dir, "triviaqa-unfiltered", "unfiltered-web-train.json"
),
},
),
datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
# These kwargs will be passed to _generate_examples
gen_kwargs={
"filepath": os.path.join(
data_dir, "triviaqa-unfiltered", "unfiltered-web-dev.json"
),
},
),
]
# method parameters are unpacked from `gen_kwargs` as given in `_split_generators`
def _generate_examples(self, filepath):
with open(filepath, encoding="utf-8") as f:
json_data = json.load(f)["Data"]
for key, data in enumerate(json_data):
search_results = []
for search_result in data["SearchResults"]:
search_results.append(
{
"description": search_result["Description"]
if "Description" in search_result
else "",
"filename": search_result["Filename"]
if "Filename" in search_result
else "",
"rank": search_result["Rank"]
if "Rank" in search_result
else -1,
"title": search_result["Title"]
if "Title" in search_result
else "",
"url": search_result["Url"]
if "Url" in search_result
else "",
"search_context": search_result["SearchContext"]
if "SearchContext" in search_result
else "",
}
)
yield key, {
"question_id": data["QuestionId"],
"question_source": data["QuestionSource"],
"question": data["Question"],
"answer": {
"aliases": data["Answer"]["Aliases"],
"value": data["Answer"]["Value"],
},
"search_results": search_results,
}
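For clarity, the chained conditionals in `_generate_examples` are just per-field fallbacks for search results that omit a key; the same mapping written with `dict.get` on a made-up record:

```python
search_result = {"Title": "Example page", "Rank": 3}  # made-up, partial record

normalized = {
    "description": search_result.get("Description", ""),
    "filename": search_result.get("Filename", ""),
    "rank": search_result.get("Rank", -1),
    "title": search_result.get("Title", ""),
    "url": search_result.get("Url", ""),
    "search_context": search_result.get("SearchContext", ""),
}
print(normalized["rank"], repr(normalized["title"]))  # 3 'Example page'
```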
@@ -20,6 +20,7 @@ def simple_evaluate(
     tasks=[],
     num_fewshot=0,
     batch_size=None,
+    max_batch_size=None,
     device=None,
     no_cache=False,
     limit=None,
@@ -41,8 +42,10 @@ def simple_evaluate(
         List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
     :param num_fewshot: int
         Number of examples in few-shot context
-    :param batch_size: int, optional
+    :param batch_size: int or str, optional
         Batch size for model
+    :param max_batch_size: int, optional
+        Maximal batch size to try with automatic batch size detection
     :param device: str, optional
         PyTorch device (e.g. "cpu" or "cuda:0") for running models
     :param no_cache: bool
@@ -71,7 +74,7 @@ def simple_evaluate(
         if model_args is None:
             model_args = ""
         lm = lm_eval.models.get_model(model).create_from_arg_string(
-            model_args, {"batch_size": batch_size, "device": device}
+            model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
         )
     elif isinstance(model, transformers.PreTrainedModel):
         lm = HFLM(
@@ -86,7 +89,7 @@ def simple_evaluate(
         lm = lm_eval.base.CachingLM(
             lm,
             "lm_cache/"
-            + model
+            + (model if isinstance(model, str) else model.model.config._name_or_path)
             + "_"
             + model_args.replace("=", "-").replace(",", "_").replace("/", "-")
             + ".db",
@@ -111,10 +114,11 @@ def simple_evaluate(
     # add info about the model and few shot config
     results["config"] = {
-        "model": model,
+        "model": (model if isinstance(model, str) else model.model.config._name_or_path),
         "model_args": model_args,
         "num_fewshot": num_fewshot,
         "batch_size": batch_size,
+        "batch_sizes": list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else [],
         "device": device,
         "no_cache": no_cache,
         "limit": limit,
......
@@ -17,6 +17,9 @@ def _get_dtype(
 class HFLM(BaseLM):
+
+    _DEFAULT_MAX_LENGTH = 2048
+
     def __init__(
         self,
         device="cuda",
@@ -26,6 +29,7 @@ class HFLM(BaseLM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        max_length=None,
         load_in_8bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         dtype: Optional[Union[str, torch.dtype]]="auto",
@@ -56,8 +60,8 @@ class HFLM(BaseLM):
                     trust_remote_code=trust_remote_code,
                 )
-        else:
+        elif isinstance(pretrained, str):

             # Initialize device
             assert isinstance(device, str)
             device_list = set(
@@ -74,10 +78,9 @@ class HFLM(BaseLM):
                 if torch.cuda.is_available()
                 else torch.device("cpu")
             )
-            assert isinstance(pretrained, str)
             revision = revision + ("/" + subfolder if subfolder is not None else "")

+            # Initialize new model and tokenizer instances
             self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
                 pretrained,
                 load_in_8bit=load_in_8bit,
@@ -92,6 +95,8 @@ class HFLM(BaseLM):
                 trust_remote_code=trust_remote_code,
             )
+        else:
+            raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel')

         self.gpt2.eval()
@@ -109,12 +114,15 @@ class HFLM(BaseLM):
         # Validate batch_size
         assert isinstance(batch_size, (int, str))

         # setup for automatic batch size detection
         if batch_size == "auto":
             self.batch_size_per_gpu = batch_size
         else:
             self.batch_size_per_gpu = int(batch_size)
+        self._max_length = max_length

     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
@@ -122,11 +130,18 @@ class HFLM(BaseLM):
     @property
     def max_length(self):
-        try:
-            return self.gpt2.config.n_ctx
-        except AttributeError:
-            # gptneoconfig doesn't have n_ctx apparently
-            return self.gpt2.config.max_position_embeddings
+        if self._max_length:  # if max length manually set, return it
+            return self._max_length
+        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
+        for attr in seqlen_config_attrs:
+            if hasattr(self.gpt2.config, attr):
+                return getattr(self.gpt2.config, attr)
+        if hasattr(self.tokenizer, "model_max_length"):
+            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
+                return self._DEFAULT_MAX_LENGTH
+            return self.tokenizer.model_max_length
+        return self._DEFAULT_MAX_LENGTH

     @property
     def max_gen_toks(self):
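The `elif isinstance(pretrained, str)` / `raise TypeError` split is what the branch name refers to: `pretrained` can now be an already-instantiated `transformers.PreTrainedModel` rather than a checkpoint name. A hedged sketch (the checkpoint name is only an example):

```python
import transformers
from lm_eval.models.gpt2 import HFLM

model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

# Wrap the existing instance; nothing is re-downloaded or reloaded.
lm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=8)
```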
......
@@ -9,7 +9,6 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
 from tqdm import tqdm
 from transformers import BatchEncoding
-from accelerate import find_executable_batch_size

 from lm_eval import utils
 from lm_eval.base import BaseLM
@@ -76,6 +75,7 @@ class HuggingFaceAutoLM(BaseLM):
         subfolder: Optional[str] = None,
         revision: Optional[str] = "main",
         batch_size: Optional[Union[int, str]] = 1,
+        max_batch_size: Optional[int] = 512,
         max_gen_toks: Optional[int] = 256,
         max_length: Optional[int] = None,
         add_special_tokens: Optional[bool] = None,
@@ -91,6 +91,8 @@ class HuggingFaceAutoLM(BaseLM):
         load_in_4bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         gptq_use_triton: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
         Args:
@@ -152,6 +154,13 @@ class HuggingFaceAutoLM(BaseLM):
             If True, will trust the remote code when loading the model.
         gptq_use_triton (bool, optional, defaults to False):
             Use Triton for GPTQ inference.
+        bnb_4bit_quant_type (str, optional, defaults to None):
+            The quantization type to use for BnB 4bit quantization. See:
+            https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L77
+        bnb_4bit_compute_dtype (Union[str, torch.dtype], optional, defaults to None):
+            The compute dtype to use for BnB 4bit quantization. See:
+            https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L74
         """
         super().__init__()
@@ -172,10 +181,13 @@ class HuggingFaceAutoLM(BaseLM):
         ), "Evaluating causal models with `add_special_tokens=True` is currently not supported."

         # setup for automatic batch size detection
-        if batch_size == "auto":
-            self._batch_size = batch_size
+        if str(batch_size).startswith("auto"):
+            batch_size = batch_size.split(":")
+            self._batch_size = batch_size[0]
+            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
         else:
             self._batch_size = int(batch_size)
+        self.max_batch_size = max_batch_size

         self._max_gen_toks = max_gen_toks
         self._max_length = max_length
@@ -212,6 +224,8 @@ class HuggingFaceAutoLM(BaseLM):
             gptq_use_triton=gptq_use_triton,
             load_in_8bit=load_in_8bit,
             load_in_4bit=load_in_4bit,
+            bnb_4bit_quant_type=bnb_4bit_quant_type,
+            bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
             **model_kwargs,
         )
         # note: peft_path can be different than pretrained model path
@@ -232,8 +246,11 @@ class HuggingFaceAutoLM(BaseLM):
             # the user specified one so we force `self._device` to be the same as
             # `lm_head`'s.
             self._device = self.model.hf_device_map["lm_head"]
-        if not use_accelerate:
-            self.model.to(self._device)
+        if not use_accelerate and not (load_in_4bit or load_in_8bit):
+            try:
+                self.model.to(self._device)
+            except:
+                print("Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore.")

     def _create_auto_model(
         self,
@@ -250,6 +267,8 @@ class HuggingFaceAutoLM(BaseLM):
         trust_remote_code: Optional[bool] = False,
         torch_dtype: Optional[Union[str, torch.dtype]] = None,
         gptq_use_triton: Optional[bool] = False,
+        bnb_4bit_quant_type: Optional[str] = None,
+        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
     ) -> transformers.AutoModel:
         """Returns a pre-trained pytorch model from a pre-trained model configuration."""
         if not quantized:
@@ -258,6 +277,9 @@ class HuggingFaceAutoLM(BaseLM):
             model_kwargs = {}
             if transformers.__version__ >= "4.30.0":
                 model_kwargs["load_in_4bit"] = load_in_4bit
+                if load_in_4bit:
+                    model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
+                    model_kwargs["bnb_4bit_compute_dtype"] = getattr(torch, bnb_4bit_compute_dtype)
             model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision + ("/" + subfolder if subfolder is not None else ""),
@@ -411,19 +433,7 @@ class HuggingFaceAutoLM(BaseLM):
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
-            @find_executable_batch_size(
-                starting_batch_size=512
-            )  # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones(
-                    (batch_size, self.max_length), device=self.device
-                ).long()
-                for _ in range(5):
-                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                return batch_size
-
-            batch_size = forward_batch()
+            batch_size = self._detect_batch_size()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
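For reference, the 4-bit options only enter `model_kwargs` when `load_in_4bit=True`, and the compute dtype may be given as a string that is resolved through `getattr(torch, ...)`. A small sketch of the resulting kwargs (values are examples, not defaults):

```python
import torch

bnb_4bit_quant_type = "nf4"          # example value
bnb_4bit_compute_dtype = "bfloat16"  # example value
load_in_4bit = True

model_kwargs = {"load_in_4bit": load_in_4bit}
if load_in_4bit:
    model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
    # string -> torch.dtype, e.g. "bfloat16" -> torch.bfloat16
    model_kwargs["bnb_4bit_compute_dtype"] = getattr(torch, bnb_4bit_compute_dtype)

print(model_kwargs)
```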
......
@@ -14,7 +14,6 @@ Homepage: https://github.com/hendrycks/test
 """
 from lm_eval.base import MultipleChoiceTask

-
 _CITATION = """
 @article{hendryckstest2021,
     title={Measuring Massive Multitask Language Understanding},
@@ -103,8 +102,8 @@ def create_task(subject):

 class GeneralHendrycksTest(MultipleChoiceTask):
-    VERSION = 0
-    DATASET_PATH = "hendrycks_test"
+    VERSION = 1
+    DATASET_PATH = "cais/mmlu"
     DATASET_NAME = None

     def __init__(self, subject):
@@ -112,7 +111,7 @@ class GeneralHendrycksTest(MultipleChoiceTask):
         super().__init__()

     def has_training_docs(self):
-        return False
+        return True

     def has_validation_docs(self):
         return True
@@ -126,41 +125,50 @@ class GeneralHendrycksTest(MultipleChoiceTask):
     def test_docs(self):
         return map(self._process_doc, self.dataset["test"])

+    def _format_subject(self, subject):
+        words = subject.split("_")
+        return " ".join(words)
+
+    def fewshot_context(self, doc, num_fewshot, **kwargs):
+        subject = self.DATASET_NAME
+        description = f"The following are multiple choice questions (with answers) about {self._format_subject(subject)}."
+        kwargs["description"] = description
+        return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs)
+
     def _process_doc(self, doc):
         def format_example(doc, keys):
             """
-            Question: <prompt>
-            Choices:
+            <prompt>
             A. <choice1>
             B. <choice2>
             C. <choice3>
             D. <choice4>
             Answer:
             """
-            prompt = "Question: " + doc["question"] + "\nChoices:\n"
-            prompt += "".join(
+            question = doc["question"].strip()
+            choices = "".join(
                 [f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])]
             )
-            prompt += "Answer:"
+            prompt = f"{question}\n{choices}Answer:"
             return prompt

         keys = ["A", "B", "C", "D"]
         return {
             "query": format_example(doc, keys),
-            "choices": doc["choices"],
-            "gold": keys.index(doc["answer"])
-            if isinstance(doc["answer"], str)
-            else doc["answer"],
+            "choices": keys,
+            "gold": doc["answer"],
         }

     def fewshot_examples(self, k, rnd):
         # fewshot_examples is not just sampling from train_docs because dev is
         # in the same distribution as val/test but auxiliary_train isn't
         if self._fewshot_docs is None:
             self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"]))

-        return rnd.sample(list(self._fewshot_docs), k)
+        # use the unchanged order of the dev set without sampling,
+        # just as in the original code https://github.com/hendrycks/test/blob/master/evaluate.py#L28
+        return self._fewshot_docs[:k]

     def doc_to_text(self, doc):
         return doc["query"]
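The reworked `_process_doc` therefore produces the canonical MMLU prompt (no `Question:`/`Choices:` scaffolding) and scores over the letter keys instead of the full answer strings. A quick illustration on a made-up document:

```python
doc = {
    "question": "What is the capital of France?",
    "choices": ["Berlin", "Madrid", "Paris", "Rome"],
    "answer": 2,  # index of the gold choice
}

keys = ["A", "B", "C", "D"]
choices = "".join(f"{k}. {c}\n" for k, c in zip(keys, doc["choices"]))
query = f"{doc['question'].strip()}\n{choices}Answer:"
print(query)
# What is the capital of France?
# A. Berlin
# B. Madrid
# C. Paris
# D. Rome
# Answer:
# gold = doc["answer"] = 2, i.e. the letter "C" is the target continuation
```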
......
@@ -10,11 +10,10 @@ high quality distant supervision for answering the questions.
 Homepage: https://nlp.cs.washington.edu/triviaqa/
 """
 import inspect
-import lm_eval.datasets.triviaqa.triviaqa
+import string
 from lm_eval.base import Task, rf
 from lm_eval.metrics import mean

 _CITATION = """
 @InProceedings{JoshiTriviaQA2017,
     author = {Joshi, Mandar and Choi, Eunsol and Weld, Daniel S. and Zettlemoyer, Luke},
@@ -29,9 +28,9 @@ _CITATION = """

 class TriviaQA(Task):
-    VERSION = 1
-    DATASET_PATH = inspect.getfile(lm_eval.datasets.triviaqa.triviaqa)
-    DATASET_NAME = None
+    VERSION = 2
+    DATASET_PATH = "trivia_qa"
+    DATASET_NAME = "rc.nocontext"

     def has_training_docs(self):
         return True
@@ -74,19 +73,27 @@ class TriviaQA(Task):
         return ret

     def construct_requests(self, doc, ctx):
-        ret = []
-        for alias in self._remove_prefixes(doc["answer"]["aliases"]):
-            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
-            ret.append(is_prediction)
-        return ret
+        """Uses RequestFactory to construct Requests and returns an iterable of
+        Requests which will be sent to the LM.
+
+        :param doc:
+            The document as returned from training_docs, validation_docs, or test_docs.
+        :param ctx: str
+            The context string, generated by fewshot_context. This includes the natural
+            language description, as well as the few shot examples, and the question
+            part of the document for `doc`.
+        """
+        continuation = rf.greedy_until(ctx, {"until": ["\n", ".", ","]})
+        return continuation

     def process_results(self, doc, results):
-        return {"acc": float(any(results))}
+        continuation = results[0].strip().lower().translate(str.maketrans('', '', string.punctuation))
+        list_of_candidates = [alias.lower().translate(str.maketrans('', '', string.punctuation)) for alias in self._remove_prefixes(doc["answer"]["aliases"])]
+        return {"em": float(continuation in list_of_candidates)}

     def aggregation(self):
         return {
-            "acc": mean,
+            "em": mean,
         }

     def higher_is_better(self):
-        return {"acc": True}
+        return {"em": True}
@@ -8,6 +8,7 @@ import sys
 import fnmatch
 from typing import List, Union

+import gc
 import torch
 from omegaconf import OmegaConf
@@ -64,11 +65,11 @@ def join_iters(iters):
         yield from iter

-def chunks(iter, n):
+def chunks(iter, n=0, fn=None):
     arr = []
-    for x in iter:
+    for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == n:
+        if len(arr) == (fn(i) if fn else n):
             yield arr
             arr = []
@@ -283,3 +284,8 @@ def run_task_tests(task_list: List[str]):
         raise ValueError(
             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
         )
+
+
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
@@ -16,6 +16,8 @@ def parse_args():
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
     parser.add_argument("--batch_size", type=str, default=None)
+    parser.add_argument("--max_batch_size", type=int, default=None,
+                        help="Maximal batch size to try with --batch_size auto")
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--output_path", default=None)
     parser.add_argument("--limit", type=float, default=None,
@@ -60,6 +62,7 @@ def main():
         tasks=task_names,
         num_fewshot=args.num_fewshot,
         batch_size=args.batch_size,
+        max_batch_size=args.max_batch_size,
         device=args.device,
         no_cache=args.no_cache,
         limit=args.limit,
@@ -78,9 +81,10 @@ def main():
         with open(args.output_path, "w") as f:
             f.write(dumped)

+    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
     print(
         f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
-        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
+        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
     )
     print(evaluator.make_table(results))
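`utils.chunks` now takes an optional `fn` that maps the current element index to the desired chunk size, which is how the batch scheduler grows batches as the length-sorted requests get shorter. A self-contained toy check (a trailing yield for the final partial chunk is included here for completeness):

```python
def chunks(iter, n=0, fn=None):
    arr = []
    for i, x in enumerate(iter):
        arr.append(x)
        if len(arr) == (fn(i) if fn else n):
            yield arr
            arr = []
    if arr:  # final partial chunk, added for this toy example
        yield arr

# Toy schedule: chunks of 2 for the first few items, 3 afterwards.
sizes = [len(c) for c in chunks(range(10), fn=lambda i: 2 if i < 4 else 3)]
print(sizes)  # [2, 2, 3, 3]
```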
......