Unverified Commit fbd712f7 authored by Hailey Schoelkopf, committed by GitHub

Merge branch 'big-refactor' into misc-cleanup-refactor

parents 909f7346 ffff47c6
......@@ -9,8 +9,8 @@ We’d like your help to test it out! You can help by:
2. Porting tasks supported in the previous version of the harness to the new YAML configuration format. Please check out our [task implementation guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md) for more information.
If you choose to port a task not yet completed according to [our checklist](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/tasks/README.md), then you can contribute it by opening a PR containing [Refactor] in the name with:
- A command of the form `python main.py --model hf-causal --model_args ..... --tasks <task name> ...` which will run the task in the `master` branch, and what the score is
- A command of the form `python main.py --model hf-causal --model_args ..... --tasks <task name> ...` to run the task in your PR branch to `big-refactor`, and what the resulting score is, to show that we achieve equality between the two implementations.
- A command of the form `python main.py --model hf --model_args ..... --tasks <task name> ...` which will run the task in the `master` branch, and what the score is
- A command of the form `python main.py --model hf --model_args ..... --tasks <task name> ...` to run the task in your PR branch to `big-refactor`, and what the resulting score is, to show that we achieve equality between the two implementations.
Lastly, we'll no longer accept new feature requests beyond those already open against the `master` branch as we carry out this switch to the new version over the next week, though we will still accept bugfixes to the `master` branch and PRs to `big-refactor`. Feel free to reach out in the #lm-thunderdome channel of the EAI discord for more information.
......@@ -59,7 +59,7 @@ To evaluate a model hosted on the [HuggingFace Hub](https://huggingface.co/model
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=EleutherAI/gpt-j-6B \
--tasks hellaswag \
--device cuda:0 \
......@@ -70,29 +70,30 @@ Additional arguments can be provided to the model constructor using the `--model
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
--tasks lambada_openai,hellaswag \
--device cuda:0 \
--batch_size 8
```
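For reference, evaluation can also be driven programmatically. The sketch below is hypothetical: it assumes `simple_evaluate` (whose signature appears in the evaluator changes later in this diff) is importable from `lm_eval.evaluator` and accepts `model`/`model_args` strings and a list of task names the same way the CLI does:

```python
# A minimal sketch, not a verbatim API reference; argument names follow the
# simple_evaluate() signature visible elsewhere in this diff.
from lm_eval.evaluator import simple_evaluate

results = simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m,revision=step100000,dtype=float",
    tasks=["lambada_openai", "hellaswag"],
    batch_size=8,
    device="cuda:0",
)
print(results)
```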
### Multi-GPU Evaluation with Hugging Face `transformers`
Models that are loaded via either `transformers.AutoModelForCausalLM` (autoregressive, decoder-only GPT-style models) or `transformers.AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface are supported via `--model hf`.
To parallelize evaluation across multiple GPUs, we allow for launching evaluation via the `accelerate` library as follows:
### Multi-GPU Evaluation with Hugging Face `accelerate`
To parallelize evaluation of HuggingFace models across multiple GPUs, we allow for two different types of multi-GPU evaluation.
The first is performed by launching evaluation via the `accelerate` library as follows:
```
accelerate launch main.py \
--model hf-causal \
--model hf \
--tasks lambada_openai,arc_easy \
--batch_size 16 \
```
### Evaluation of Seq2Seq Models
To evaluate models that are loaded via `AutoSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface, you instead use `--model hf-seq2seq`. Support for this model type is currently pending.
This will perform *data-parallel evaluation*: that is, placing a **single full copy** of your model onto each available GPU and *splitting batches across GPUs* to evaluate on K GPUs K times faster than on one.
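Conceptually, data-parallel evaluation shards the list of requests across processes and then stitches the per-rank results back into their original order. The sketch below only illustrates that idea (the `rank` and `world_size` names mirror the `LM` properties elsewhere in this diff); it is not the harness's actual dispatch code:

```python
# Illustration only: each of K processes scores an interleaved slice of the
# requests, so wall-clock time drops roughly by a factor of K.
def shard_for_rank(requests, rank, world_size):
    return requests[rank::world_size]

# e.g. with world_size=4, rank 1 handles items 1, 5, 9, ...
assert shard_for_rank(list(range(10)), rank=1, world_size=4) == [1, 5, 9]
```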
> **Warning**: Choosing the wrong model may result in erroneous outputs even though no error is raised.
### Commercial APIs
......@@ -139,7 +140,7 @@ This will write out one text file for each task.
For models loaded with the HuggingFace `transformers` library, any arguments provided via `--model_args` get passed to the relevant constructor directly. This means that anything you can do with `AutoModel` can be done with our library. For example, you can pass a local path via `pretrained=` or use models finetuned with [PEFT](https://github.com/huggingface/peft) by taking the call you would run to evaluate the base model and adding `,peft=PATH` to the `model_args` argument:
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=EleutherAI/gpt-j-6b,peft=nomic-ai/gpt4all-j-lora \
--tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
--device cuda:0
......@@ -149,7 +150,7 @@ GPTQ quantized models can be loaded by specifying their file names in `,quantize
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=model-name-or-path,quantized=model.safetensors,gptq_use_triton=True \
--tasks hellaswag
```
......
import abc
import os
from typing import Union
from sqlitedict import SqliteDict
import json
import hashlib
from tqdm import tqdm
from lm_eval import utils
from lm_eval.logger import eval_logger
class LM(abc.ABC):
......@@ -12,6 +19,7 @@ class LM(abc.ABC):
(inputs/outputs should be tokenization-agnostic.)
"""
self.cache_hook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests):
......@@ -118,3 +126,104 @@ class LM(abc.ABC):
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return 1
def set_cache_hook(self, cache_hook):
self.cache_hook = cache_hook
### SQLite-based caching of LM responses
def hash_args(attr, args):
dat = json.dumps([attr] + list(args))
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
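# Illustrative example (hypothetical arguments): identical (attr, args) pairs
# always produce the same key, so repeated requests can be served from cache:
#   hash_args("loglikelihood", ("Q: 2+2=", " 4")) == hash_args("loglikelihood", ("Q: 2+2=", " 4"))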
class CacheHook:
def __init__(self, cachinglm):
if cachinglm is None:
self.dbdict = None
return
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res):
if self.dbdict is None:
return
hsh = hash_args(attr, req)
self.dbdict[hsh] = res
class CachingLM:
def __init__(self, lm, cache_db):
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
Underlying LM
:param cache_db: str
Path to cache db
"""
self.lm = lm
self.cache_db = cache_db
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
self.dbdict = SqliteDict(cache_db, autocommit=True)
# add hook to lm
lm.set_cache_hook(self.get_cache_hook())
def __getattr__(self, attr):
lm_attr = getattr(self.lm, attr)
if not callable(lm_attr):
return lm_attr
def fn(requests):
res = []
remaining_reqs = []
warned = False
# figure out which ones are cached and which ones are new
eval_logger.info(
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
)
for req in tqdm(requests):
hsh = hash_args(attr, req.args)
if attr == "greedy_until" and req.args[1].get("do_sample", False):
# when we are doing non-greedy generation, don't use the cache
# (else every "randomly sampled" generation would be identical for repeats > 1).
if not warned:
eval_logger.warning(
f"Arguments to lm.greedy_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
)
warned = True
res.append(None)
remaining_reqs.append(req)
elif hsh in self.dbdict:
ob = self.dbdict[hsh]
assert ob is not None
res.append(ob)
else:
res.append(None)
remaining_reqs.append(req)
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
for req, r in zip(remaining_reqs, rem_res):
while res[resptr] is not None:
resptr += 1
res[resptr] = r
# caching
hsh = hash_args(attr, req.args)
self.dbdict[hsh] = r
self.dbdict.commit()
return res
return fn
def get_cache_hook(self):
return CacheHook(self)
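# Sketch of typical usage (path and variable names are illustrative): wrapping
# an LM makes repeated runs read previously computed responses from the SQLite file.
#   lm = CachingLM(base_lm, "lm_cache/my_run.db")
#   outputs = lm.loglikelihood(requests)  # misses are computed by base_lm and then cached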
......@@ -52,7 +52,6 @@ class TaskConfig(dict):
task: str = None
group: Union[str, list] = None
reference: str = None
dataset_path: str = None
dataset_name: str = None
......@@ -67,6 +66,8 @@ class TaskConfig(dict):
doc_to_target: Union[Callable, str] = None
use_prompt: str = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
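# (descriptive note) target_delimiter is placed between an example's input text
# and its target, while fewshot_delimiter joins successive few-shot examples,
# e.g. "<input> <target>\n\n<input> <target>\n\n<test input>"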
num_fewshot: int = 0
batch_size: int = 1
......@@ -76,8 +77,6 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until"
generation_kwargs: dict = None
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
filter_list: Union[str, list] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
......@@ -343,6 +342,7 @@ class Task(abc.ABC):
fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random()
)
# TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(
doc=doc,
......@@ -719,12 +719,14 @@ class ConfigurableTask(Task):
raise TypeError
def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref.
# returns a version of the gold target answer to a document,
# which should be passed into metric for scoring as the ground truth.
# in multiple_choice tasks, this should be castable to an int corresponding to the index
# within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
if self._config.gold_alias is not None:
doc_to_target = self._config.gold_alias
else:
# doc_to_target = self._config.doc_to_target
return self.doc_to_target(doc)
if type(doc_to_target) == str:
......
......@@ -39,7 +39,7 @@ def simple_evaluate(
batch_size=None,
max_batch_size=None,
device=None,
no_cache=False,
use_cache=None,
limit=None,
bootstrap_iters=100000,
check_integrity=False,
......@@ -64,8 +64,8 @@ def simple_evaluate(
Maximal batch size to try with automatic batch size detection
:param device: str, optional
PyTorch device (e.g. "cpu" or "cuda:0") for running models
:param no_cache: bool
Whether or not to cache
:param use_cache: str, optional
A path to a sqlite db file for caching model responses. `None` if not caching.
:param limit: int or float, optional
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
:param bootstrap_iters:
......@@ -99,6 +99,16 @@ def simple_evaluate(
assert isinstance(model, lm_eval.api.model.LM)
lm = model
if use_cache is not None:
print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
lm = lm_eval.api.model.CachingLM(
lm,
use_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ "_rank" + str(lm.rank) + ".db",
)
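# For example (hypothetical path), use_cache="lm_cache/run" yields
# "lm_cache/run_rank0.db", "lm_cache/run_rank1.db", ...: one cache file per rank.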
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
if check_integrity:
......@@ -127,7 +137,7 @@ def simple_evaluate(
if hasattr(lm, "batch_sizes")
else [],
"device": device,
"no_cache": no_cache,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
......
from . import hf_causal
from . import huggingface
from . import openai_completions
from . import anthropic_llms
from . import textsynth
from . import dummy
from . import huggingface
# TODO: implement __all__
......@@ -88,6 +88,8 @@ class AnthropicLM(LM):
if not requests:
return []
requests = [req.args for req in requests]
res = []
for request in tqdm(requests):
inp = request[0]
......@@ -102,6 +104,9 @@ class AnthropicLM(LM):
stop=until,
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
return res
def _model_call(self, inps):
......
import torch
import transformers
import copy
from tqdm import tqdm
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from typing import Optional, Union
@register_model("hf-causal")
class HFCausalLM(LM):
def __init__(
self,
device="cuda",
pretrained="gpt2",
revision="main",
low_cpu_mem_usage=None,
dtype: Optional[Union[str, torch.dtype]] = "auto",
subfolder=None,
tokenizer=None,
batch_size=1,
):
super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
else:
eval_logger.info("Device not specified")
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
self._rank = 0
self._world_size = 1
else:
self._device = "cpu"
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=utils.get_dtype(dtype),
).to(self.device)
self.model.eval()
eval_logger.info(self.model.dtype)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
)
self.vocab_size = self.tokenizer.vocab_size
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# multigpu support with accelerate
if gpus > 1:
accelerator = Accelerator()
if gpus > accelerator.num_processes:
eval_logger.warning(
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices."
)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
# manually set model to use gpu, for case where many GPUs available but
# only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else:
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
try:
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).config.n_ctx
else:
return self.model.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(
self.model
).config.max_position_embeddings
else:
return self.model.config.max_position_embeddings
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0]
)
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
else:
return self.model.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
def loglikelihood(self, requests):
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc = [self.eot_token_id]
else:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
pad_amnt = 0
if self.world_size > 1:
# Gather the number of rolling windows on every rank and pad shorter ranks
# with copies of their first window, so that all ranks execute the same
# number of batches in lockstep; the padded results are discarded below.
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
gathered = (
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
)
pad_amnt = max(gathered) - gathered[self.rank]
if pad_amnt > 0:
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
if (self.world_size > 1) and (pad_amnt > 0):
string_nll = [x[0] for x in string_nll[:-pad_amnt]]
else:
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return -len(toks), tuple(toks)
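# e.g. (context + continuation) token lengths [5, 12, 7] are processed in
# order 12, 7, 5, so the first element of each batch fixes its padded length.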
# TODO: automatic (variable) batch size detection for vectorization
re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
self.batch_size,
):
inps = []
cont_toks_list = []
inplens = []
padding_length = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
).to(self.device)
(inplen,) = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = (
padding_length if padding_length is not None else inplen
)
# pad length from seq to padding_length
inp = torch.cat(
[
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(
inp.device
), # [padding_length - seq]
],
dim=0,
)
inps.append(inp.unsqueeze(0)) # [1, padding_length]
cont_toks_list.append(cont)
inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
chunk, multi_logits, inps, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles until that are
# multiple tokens or that span multiple tokens correctly
res = []
def _collate(x):
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
primary_until = until[0]
# try:
# (primary_until,) = self.tok_encode(until[0])
# except Exception:
# # if our primary until would be multiple tokens long, we'll have errors.
# # TODO: handling this better will let us stop generating earlier + often.
# primary_until = self.eot_token_id
context_enc = torch.tensor(
[self.tok_encode(context)[max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
for term in until:
s = s.split(term)[0]
res.append(s)
return re_ord.get_original(res)
......@@ -3,6 +3,7 @@ import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
import copy
from collections import defaultdict
from tqdm import tqdm
import torch.nn.functional as F
......@@ -15,6 +16,7 @@ from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from typing import List, Union
@register_model("hf-auto", "hf", "huggingface")
......@@ -99,6 +101,7 @@ class HFLM(LM):
)
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self._max_length = max_length
......@@ -204,6 +207,33 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self, strings: List[str], padding_side="left", left_truncate_len=None
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
encoding = self.tokenizer(
strings,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
)
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len:
]
self.tokenizer.padding_side = old_padding_side
return encoding["input_ids"], encoding["attention_mask"]
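# Note: with padding_side="left", shorter prompts in the batch are padded on the
# left so every row's final context token sits at the right edge, and left-truncation
# drops the oldest tokens first, which is what decoder-only generation expects.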
def tok_decode(self, tokens):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
return self.tokenizer.decode(tokens)
......@@ -486,74 +516,117 @@ class HFLM(LM):
res.append(answer)
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
res = []
res = defaultdict(list)
re_ords = {}
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
(primary_until) = until[0]
return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
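# e.g. requests generated with {"until": ["\n"]} and requests generated with
# {"until": ["\n"], "do_sample": True, "temperature": 0.8} fall into different
# groups, so differing generation settings are never mixed in a single batch.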
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
for chunk in utils.chunks(
re_ord.get_reordered(),
self.batch_size,
):
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
(primary_until) = until[0]
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts, left_truncate_len=max_ctx_len
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**kwargs,
)
context_enc = torch.tensor(
[self.tok_encode(context, left_truncate_len=max_ctx_len)],
device=self.device,
)
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
cont_toks = cont_toks[context_enc.shape[1] :]
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont_toks)
cont_toks_list = cont[0].tolist()
# discard context toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
cont_toks_list = cont_toks_list[context_enc.shape[1] :]
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
s = self.tok_decode(cont_toks_list)
res[key].append(s)
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0: # ignore '' separator, for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
self.cache_hook.add_partial(
"greedy_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
res.append(s)
pbar.close()
return re_ord.get_original(res)
return grouper.get_original(res)
......@@ -194,7 +194,7 @@ class OpenaiCompletionsLM(LM):
yield ret, lastuntil
# todo: more intelligent batching for heterogeneous `until`
for chunk, until in tqdm(
for chunk, request_args in tqdm(
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
):
inps = []
......@@ -203,6 +203,13 @@ class OpenaiCompletionsLM(LM):
inp = context_enc[-(self.max_length - self.max_gen_toks) :]
inps.append(inp)
try:
until = request_args["until"][
0
] # TODO: does this handle a list of stop seqs correctly?
except KeyError:
until = "<|endoftext|>"
response = oa_completion(
engine=self.engine,
prompt=inps,
......@@ -212,14 +219,19 @@ class OpenaiCompletionsLM(LM):
stop=until,
)
for resp, (context, until_) in zip(response.choices, chunk):
for resp, (context, args_) in zip(response.choices, chunk):
s = resp["text"]
until_ = args_.get("until", [])
for term in until_:
s = s.split(term)[0]
if len(term) > 0:
s = s.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until_), s)
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
)
res.append(s)
......
......@@ -101,6 +101,10 @@ class TextSynthLM(LM):
logprob = resp["logprob"]
is_greedy = resp["is_greedy"]
res.append((logprob, is_greedy))
self.cache_hook.add_partial(
"loglikelihood", (context, continuation), (logprob, is_greedy)
)
else:
logger.error(
f"The following response does not contain `logprobs`. Got:\n{resp}"
......@@ -141,6 +145,8 @@ class TextSynthLM(LM):
if "text" in resp:
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......
# v1.0 Tasks
This list keeps track of which tasks' implementations have been ported to YAML / v2.0 of the Eval Harness.
Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation.
Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation. (WIP) Denotes that there exists a PR or person working on this task already.
- [ ] Glue (WIP)
- [x] SuperGlue
......@@ -14,23 +14,23 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [x] PiQA
- [ ] PROST
- [ ] MCTACO
- [ ] Pubmed QA
- [ ] Pubmed QA (WIP)
- [x] SciQ
- [ ] QASPER
- [ ] QA4MRE
- [ ] TriviaQA
- [x] AI2 ARC
- [ ] LogiQA
- [ ] HellaSwag
- [ ] SWAG
- [x] HellaSwag
- [ ] SWAG (WIP)
- [x] OpenBookQA
- [ ] SQuADv2
- [ ] RACE
- [ ] RACE (WIP)
- [ ] HeadQA
- [ ] MathQA
- [ ] WebQs
- [ ] WSC273
- [ ] Winogrande
- [ ] Winogrande (WIP)
- [x] ANLI
- [ ] Hendrycks Ethics
- [ ] TruthfulQA
......@@ -38,7 +38,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] Hendrycks Math
- [ ] Asdiv
- [ ] GSM8k
- [ ] Arithmetic
- [ ] Arithmetic (WIP)
- [ ] MMMLU
- [ ] Translation (WMT) suite
- [ ] Unscramble
......
......@@ -201,6 +201,64 @@ class Reorderer:
return res
class Grouper:
"""
Takes an array `arr` and a function `fn`, and returns a dictionary whose keys
are `fn(ob)` for each `ob` in `arr` and whose values are lists of all objects
in `arr` satisfying `key == fn(ob)`.
"""
def __init__(self, arr, fn):
# self.orig_arr = arr
self.size = len(arr)
arr = list(enumerate(arr))
def group_return_dict(arr, fn):
res = collections.defaultdict(list)
for ob in arr:
res[fn(ob)].append(ob)
return res
arr = group_return_dict(arr, lambda x: fn(x[1]))
# self.arr has format Dict[Tuple[int, <entry from orig. arr>]]
self.arr = arr
self._grouped = None
def get_grouped(self):
# return the contents but not indices for our grouped dict.
if self._grouped:
return self._grouped
grouped = {}
for key in self.arr.keys():
# drop the index from each element of self.arr
grouped[key] = [y[1] for y in self.arr[key]]
self._grouped = grouped
return grouped
def get_original(self, grouped_dict):
# take in a grouped dictionary with e.g. results for each key listed
# in the same order as the instances in `self.arr`, and
# return the results in the same (single list) order as `self.orig_arr`.
res = [None] * self.size
cov = [False] * self.size
# orig = [None] * self.size
assert grouped_dict.keys() == self.arr.keys()
for key in grouped_dict.keys():
for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
res[ind] = v
cov[ind] = True
# orig[ind] = _
assert all(cov)
# assert orig == self.orig_arr
return res
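# Illustrative usage:
#   g = Grouper([3, 1, 4, 1, 5], lambda x: x % 2)
#   g.get_grouped()   # {1: [3, 1, 1, 5], 0: [4]}
#   g.get_original({1: ["a", "b", "c", "d"], 0: ["e"]})   # ["a", "b", "e", "c", "d"]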
def make_table(result_dict):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter
......@@ -406,6 +464,7 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="r
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0]
if tensor_len < max_length:
if padding_side == "right":
......
......@@ -39,7 +39,7 @@ def parse_args():
"If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--data_sampling", type=float, default=None)
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_cache", type=str, default=None)
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
......@@ -85,7 +85,7 @@ def main():
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
use_cache=args.use_cache,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
......