Unverified Commit ef4a26f6 authored by Lintang Sutawika, committed by GitHub

Merge pull request #613 from EleutherAI/add-back-cache

[Refactor] batch generation better for `hf` model ; deprecate `hf-causal` in new release
parents 4e0d0e3a c5dbf289
......@@ -9,8 +9,8 @@ We’d like your help to test it out! you can help by:
2. Porting tasks supported in the previous version of the harness to the new YAML configuration format. Please check out our [task implementation guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md) for more information.
If you choose to port a task not yet completed according to [our checklist](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/tasks/README.md), then you can contribute it by opening a PR containing [Refactor] in the name with:
- A command of the form `python main.py --model hf-causal --model_args ..... --tasks <task name> ...` which will run the task in the `master` branch, and what the score is
- A command of the form `python main.py --model hf-causal --model_args ..... --tasks <task name> ...` to run the task in your PR branch to `big-refactor`, and what the resulting score is, to show that we achieve equality between the two implementations.
- A command of the form `python main.py --model hf --model_args ..... --tasks <task name> ...` which will run the task in the `master` branch, and what the score is
- A command of the form `python main.py --model hf --model_args ..... --tasks <task name> ...` to run the task in your PR branch to `big-refactor`, and what the resulting score is, to show that we achieve equality between the two implementations.
Lastly, we'll no longer be accepting new feature requests beyond those that are already open to the master branch as we carry out this switch to the new version over the next week, though we will be accepting bugfixes to `master` branch and PRs to `big-refactor`. Feel free to reach out in the #lm-thunderdome channel of the EAI discord for more information.
......@@ -59,7 +59,7 @@ To evaluate a model hosted on the [HuggingFace Hub](https://huggingface.co/model
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=EleutherAI/gpt-j-6B \
--tasks hellaswag \
--device cuda:0 \
......@@ -70,29 +70,30 @@ Additional arguments can be provided to the model constructor using the `--model
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
--tasks lambada_openai,hellaswag \
--device cuda:0 \
--batch_size 8
```
### Multi-GPU Evaluation with Hugging Face `transformers`
Models that are loaded via either `transformers.AutoModelForCausalLM` (autoregressive, decoder-only GPT style models) or `transformers.AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface are supported via `--model hf`.
To parallelize evaluation across multiple GPUs, we allow for launching evaluation via the `accelerate` library as follows:
### Multi-GPU Evaluation with Hugging Face `accelerate`
To parallelize evaluation of HuggingFace models across multiple GPUs, we allow for two different types of multi-GPU evaluation.
The first is performed by launching evaluation via the `accelerate` library as follows:
```
accelerate launch main.py \
--model hf-causal \
--model hf \
--tasks lambada_openai,arc_easy \
--batch_size 16 \
```
### Evaluation of Seq2Seq Models
To evaluate models that are loaded via `AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface, you instead use `--model hf-seq2seq`. Support for this model type is currently pending.
This will perform *data-parallel evaluation*: that is, placing a **single full copy** of your model onto each available GPU and *splitting batches across GPUs* to evaluate on K GPUs K times faster than on one.
> **Warning**: Choosing the wrong model may result in erroneous outputs despite not erroring.
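As a rough illustration (hypothetical names, not the harness's internal API), the data-parallel scheme amounts to each spawned process scoring an interleaved shard of the request list with its own full model copy:
```python
# Minimal sketch of data-parallel evaluation; `rank` and `world_size` would be
# supplied per-process by `accelerate launch`, but are hardcoded here.
requests = [f"request {i}" for i in range(100)]  # stand-ins for evaluation requests
world_size, rank = 4, 0

# Each process evaluates every world_size-th request with its own model copy,
# so K GPUs finish roughly K times faster than one.
my_shard = requests[rank::world_size]
print(f"rank {rank} evaluates {len(my_shard)} of {len(requests)} requests")
```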
### Commercial APIs
......@@ -139,7 +140,7 @@ This will write out one text file for each task.
For models loaded with the HuggingFace `transformers` library, any arguments provided via `--model_args` get passed to the relevant constructor directly. This means that anything you can do with `AutoModel` can be done with our library. For example, you can pass a local path via `pretrained=` or use models finetuned with [PEFT](https://github.com/huggingface/peft) by taking the call you would run to evaluate the base model and adding `,peft=PATH` to the `model_args` argument:
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=EleutherAI/gpt-j-6b,peft=nomic-ai/gpt4all-j-lora \
--tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
--device cuda:0
......@@ -149,7 +150,7 @@ GPTQ quantized models can be loaded by specifying their file names in `,quantize
```bash
python main.py \
--model hf-causal \
--model hf \
--model_args pretrained=model-name-or-path,quantized=model.safetensors,gptq_use_triton=True \
--tasks hellaswag
```
......
......@@ -718,12 +718,14 @@ class ConfigurableTask(Task):
raise TypeError
def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref.
# returns a version of the gold target answer for a document,
# which should be passed into the metric for scoring as the ground truth.
# in multiple_choice tasks, this should be castable to an int corresponding to the index
# within the answer choices, while doc_to_target is the string version of {{answer_choices[gold]}}.
if self._config.gold_alias is not None:
doc_to_target = self._config.gold_alias
else:
# doc_to_target = self._config.doc_to_target
return self.doc_to_target(doc)
if type(doc_to_target) == str:
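# Toy illustration of the distinction drawn above (hypothetical doc, not a real task):
# doc = {"query": "2 + 2 = ?", "choices": ["3", "4"], "gold": 1}
# doc_to_target(doc) -> "4" (the string {{answer_choices[gold]}})
# gold_alias target -> 1 (castable to an int index into the answer choices)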
......@@ -772,7 +774,7 @@ class ConfigurableTask(Task):
Instance(
request_type="loglikelihood",
doc=doc,
arguments=("", " {}".format(choice)),
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
......
from . import hf_causal
from . import huggingface
from . import openai_completions
from . import anthropic_llms
from . import textsynth
from . import dummy
from . import huggingface
# TODO: implement __all__
import torch
import transformers
import copy
from tqdm import tqdm
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from typing import Optional, Union
@register_model("hf-causal")
class HFCausalLM(LM):
def __init__(
self,
device="cuda",
pretrained="gpt2",
revision="main",
low_cpu_mem_usage=None,
dtype: Optional[Union[str, torch.dtype]] = "auto",
subfolder=None,
tokenizer=None,
batch_size=1,
):
super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
else:
eval_logger.info("Device not specified")
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
self._rank = 0
self._world_size = 1
else:
self._device = "cpu"
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained,
revision=revision,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=utils.get_dtype(dtype),
).to(self.device)
self.model.eval()
eval_logger.info(self.model.dtype)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
)
self.vocab_size = self.tokenizer.vocab_size
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# multigpu support with accelerate
if gpus > 1:
accelerator = Accelerator()
if gpus > accelerator.num_processes:
eval_logger.warning(
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices."
)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
# manually set the model to use the GPU, for the case where many GPUs are
# available but we only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else:
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
try:
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).config.n_ctx
else:
return self.model.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(
self.model
).config.max_position_embeddings
else:
return self.model.config.max_position_embeddings
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0]
)
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
else:
return self.model.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
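# Illustrative call (assumed values): for a context batch of shape [batch, seq],
# _model_generate(context, max_length=seq + self.max_gen_toks, stop="\n\n")
# decodes greedily (do_sample defaults to False above) until max_length is hit
# or the stop string "\n\n" is generated, per the stopping criteria built above.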
def loglikelihood(self, requests):
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc = [self.eot_token_id]
else:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
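# Sketch of the preprocessing above (illustrative strings): a request with
# args ("The capital of France is", " Paris") is encoded as
# context_enc = tok_encode("The capital of France is") and
# continuation_enc = tok_encode(" Paris"); an empty context becomes
# [eot_token_id] so the model still has a token to condition on.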
def loglikelihood_rolling(self, requests):
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
pad_amnt = 0
if self.world_size > 1:
# gather the per-rank window counts, then pad this rank's windows up to the
# max count so every rank runs the same number of forward passes (otherwise
# the collective gather would hang); the padded results are dropped below
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
gathered = (
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
)
pad_amnt = max(gathered) - gathered[self.rank]
if pad_amnt > 0:
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
if (self.world_size > 1) and (pad_amnt > 0):
string_nll = [x[0] for x in string_nll[:-pad_amnt]]
else:
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
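# Rough picture of the rolling windows built above: a string longer than
# max_length is split into disjoint continuation chunks, each scored while
# conditioning on the tokens immediately preceding it (the first window
# conditions on the EOT prefix token), so every token is scored exactly once.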
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be overestimates rather than underestimates, which is more useful for planning
# - when going through the sorted list, the first element of each batch always has the batch's
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return -len(toks), tuple(toks)
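# e.g. for x = (key, [10, 11], [12]) the sort key is (-3, (10, 11, 12)), so the
# longest (context + continuation) sequences come first in the reordered list.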
# TODO: automatic (variable) batch size detection for vectorization
re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
self.batch_size,
):
inps = []
cont_toks_list = []
inplens = []
padding_length = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
# when too long to fit in context, truncate from the left
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
).to(self.device)
(inplen,) = inp.shape
cont = continuation_enc
# since in _collate we make sure length is descending, the longest is always the first one.
padding_length = (
padding_length if padding_length is not None else inplen
)
# pad length from seq to padding_length
inp = torch.cat(
[
inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long).to(
inp.device
), # [padding_length - seq]
],
dim=0,
)
inps.append(inp.unsqueeze(0)) # [1, padding_length]
cont_toks_list.append(cont)
inplens.append(inplen)
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length]
multi_logits = F.log_softmax(
self._model_call(batched_inps), dim=-1
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
chunk, multi_logits, inps, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
logits = logits[inplen - contlen : inplen].unsqueeze(
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(answer)
return re_ord.get_original(res)
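# e.g. for continuation tokens [4, 5], answer = (logp, is_greedy) where
# logp = log P(4 | ctx) + log P(5 | ctx, 4) and is_greedy is True iff both
# tokens were also the per-position argmax predictions.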
def greedy_until(self, requests):
# TODO: implement fully general `until` that handles until that are
# multiple tokens or that span multiple tokens correctly
res = []
def _collate(x):
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
primary_until = until[0]
# try:
# (primary_until,) = self.tok_encode(until[0])
# except Exception:
# # if our primary until would be multiple tokens long, we'll have errors.
# # TODO: handling this better will let us stop generating earlier + often.
# primary_until = self.eot_token_id
context_enc = torch.tensor(
[self.tok_encode(context)[max_gen_toks - self.max_length :]]
).to(self.device)
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
for term in until:
s = s.split(term)[0]
res.append(s)
return re_ord.get_original(res)
......@@ -3,6 +3,7 @@ import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
import copy
from collections import defaultdict
from tqdm import tqdm
import torch.nn.functional as F
......@@ -15,6 +16,7 @@ from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from typing import List, Union
@register_model("hf-auto", "hf", "huggingface")
......@@ -99,6 +101,7 @@ class HFLM(LM):
)
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self._max_length = max_length
......@@ -204,6 +207,33 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self, strings: List[str], padding_side="left", left_truncate_len=None
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
encoding = self.tokenizer(
strings,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
)
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len:
]
self.tokenizer.padding_side = old_padding_side
return encoding["input_ids"], encoding["attention_mask"]
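# Illustrative use (assumed tokenization): tok_batch_encode(["a b c", "d"],
# left_truncate_len=2) left-pads the shorter row to the longest length, then
# keeps only the last 2 token columns, returning input_ids and attention_mask
# both of shape [2, 2].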
def tok_decode(self, tokens):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
return self.tokenizer.decode(tokens)
......@@ -491,75 +521,112 @@ class HFLM(LM):
return re_ord.get_original(res)
def greedy_until(self, requests):
res = []
res = defaultdict(list)
re_ords = {}
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be overestimates rather than underestimates, which is more useful for planning
# - when going through the sorted list, the first element of each batch always has the batch's
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer([req.args for req in requests], _collate)
return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
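# e.g. requests with args ("ctx1", {"until": ["\n"]}) and ("ctx2", {"until": ["\n"]})
# share one group key, while ("ctx3", {"do_sample": True, "temperature": 0.8}) forms
# its own group, so a single batch never mixes two generation configurations.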
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
for chunk in utils.chunks(
re_ord.get_reordered(),
self.batch_size,
):
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
# this is safe to assume because the `grouper` object ensures it.
gen_kwargs = all_gen_kwargs[0]
# unpack our keyword arguments.
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
(primary_until) = until[0]
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
for context, gen_kwargs in tqdm(
re_ord.get_reordered(), disable=(self.rank != 0)
):
until = None
if isinstance(gen_kwargs, dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in kwargs.keys():
until = kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts, left_truncate_len=max_ctx_len
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
# perform batched generation
cont = self._model_generate(
context=context_enc,
attention_mask=attn_masks,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**kwargs,
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
(primary_until) = until[0]
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
context_enc = torch.tensor(
[self.tok_encode(context, left_truncate_len=max_ctx_len)],
device=self.device,
)
cont_toks_list = cont.tolist()
for cont_toks, context in zip(cont_toks_list, contexts):
# discard context + left-padding toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
cont_toks = cont_toks[context_enc.shape[1] :]
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**kwargs,
)
s = self.tok_decode(cont_toks)
cont_toks_list = cont[0].tolist()
# discard context toks if using causal decoder-only LM
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
cont_toks_list = cont_toks_list[context_enc.shape[1] :]
s = self.tok_decode(cont_toks_list)
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0:
# ignore '' separator,
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0]
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
if len(term) > 0: # ignore '' separator, for seq2seq case where the EOT token decodes to ''
s = s.split(term)[0]
res[key].append(s)
res.append(s)
self.cache_hook.add_partial(
"greedy_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
pbar.close()
return re_ord.get_original(res)
return grouper.get_original(res)
......@@ -230,6 +230,64 @@ class Reorderer:
return res
class Grouper:
"""
takes an array `arr` and a function `fn`, and returns a dictionary with a key
fn(ob) for each ob in `arr`, whose value is the list of all objects in `arr`
satisfying `key == fn(ob)`.
"""
def __init__(self, arr, fn):
# self.orig_arr = arr
self.size = len(arr)
arr = list(enumerate(arr))
def group_return_dict(arr, fn):
res = collections.defaultdict(list)
for ob in arr:
res[fn(ob)].append(ob)
return res
arr = group_return_dict(arr, lambda x: fn(x[1]))
# self.arr has format Dict[key, List[Tuple[int, <entry from orig. arr>]]]
self.arr = arr
self._grouped = None
def get_grouped(self):
# return the contents but not indices for our grouped dict.
if self._grouped:
return self._grouped
grouped = {}
for key in self.arr.keys():
# drop the index from each element of self.arr
grouped[key] = [y[1] for y in self.arr[key]]
self._grouped = grouped
return grouped
def get_original(self, grouped_dict):
# take in a grouped dictionary with e.g. results for each key listed
# in the same order as the instances in `self.arr`, and
# return the results in the same (single list) order as `self.orig_arr`.
res = [None] * self.size
cov = [False] * self.size
# orig = [None] * self.size
assert grouped_dict.keys() == self.arr.keys()
for key in grouped_dict.keys():
for (ind, _), v in zip(self.arr[key], grouped_dict[key]):
res[ind] = v
cov[ind] = True
# orig[ind] = _
assert all(cov)
# assert orig == self.orig_arr
return res
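# Example round trip (illustrative):
# g = Grouper(["a", "bb", "c"], fn=len)
# g.get_grouped() -> {1: ["a", "c"], 2: ["bb"]}
# g.get_original({1: ["A", "C"], 2: ["BB"]}) -> ["A", "BB", "C"]
# i.e. per-group results are stitched back into the original list order.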
def make_table(result_dict):
"""Generate table of results."""
from pytablewriter import MarkdownTableWriter, LatexTableWriter
......@@ -434,6 +492,7 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="r
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0]
if tensor_len < max_length:
if padding_side == "right":
......