Unverified Commit b1b4793c authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #3 from EleutherAI/seq2seq-support

Test + combine Causal and Seq2Seq multi-GPU
parents a6c640d3 9cf4a104
...@@ -73,7 +73,7 @@ class TaskConfig(dict): ...@@ -73,7 +73,7 @@ class TaskConfig(dict):
repeats: int = 1 repeats: int = 1
metric_list: str = None metric_list: str = None
gold_alias: str = None gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
generation_kwargs: dict = None generation_kwargs: dict = None
delimiter: str = "\n\n" delimiter: str = "\n\n"
...@@ -95,15 +95,18 @@ class TaskConfig(dict): ...@@ -95,15 +95,18 @@ class TaskConfig(dict):
self.doc_to_target = self.template_aliases + self.doc_to_target self.doc_to_target = self.template_aliases + self.doc_to_target
if type(self.gold_alias) == str: if type(self.gold_alias) == str:
self.gold_alias = self.template_aliases + self.doc_to_target self.gold_alias = self.template_aliases + self.gold_alias
if self.generation_kwargs or self.output_type == "greedy_until": if self.generation_kwargs:
assert ( assert (
self.output_type == "greedy_until" self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!" ), "passed `generation_kwargs`, but not using a generation request type!"
elif self.output_type == "greedy_until":
# ensure that we greedily generate in absence of explicit arguments otherwise # ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0} self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
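For illustration only (not part of this diff): a minimal stand-alone sketch of the __post_init__ rule above, showing that generation_kwargs are only accepted for the greedy_until output type and that greedy defaults are filled in otherwise. The function and variable names here are invented for the example.

def resolve_generation_kwargs(output_type: str, generation_kwargs: dict = None) -> dict:
    # generation arguments only make sense for generation-style requests
    if generation_kwargs:
        assert output_type == "greedy_until", "passed `generation_kwargs`, but not using a generation request type!"
        return generation_kwargs
    if output_type == "greedy_until":
        # ensure we greedily generate in the absence of explicit arguments
        return {"do_sample": False, "temperature": 0.0}
    return {}

print(resolve_generation_kwargs("greedy_until"))  # {'do_sample': False, 'temperature': 0.0}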
...@@ -122,6 +125,9 @@ class TaskConfig(dict): ...@@ -122,6 +125,9 @@ class TaskConfig(dict):
for k, v in list(cfg_dict.items()): for k, v in list(cfg_dict.items()):
if v is None: if v is None:
cfg_dict.pop(k) cfg_dict.pop(k)
elif isinstance(v, Callable):
# TODO: this should handle Promptsource template objects as a separate case?
cfg_dict[k] = str(v)
return cfg_dict return cfg_dict
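As a hedged sketch of the new dump_config behavior (using a plain dict in place of the real TaskConfig attributes): unset fields are dropped, and callables, such as fields built with the !function constructor, are stringified so the dumped config stays serializable.

from typing import Callable

def dump_config(cfg_dict: dict) -> dict:
    for k, v in list(cfg_dict.items()):
        if v is None:
            cfg_dict.pop(k)          # drop unset fields
        elif isinstance(v, Callable):
            cfg_dict[k] = str(v)     # stringify callables (e.g. !function-built fields)
    return cfg_dict

print(dump_config({"metric_list": None, "gold_alias": lambda doc: doc["answer"]}))
# -> {'gold_alias': '<function <lambda> at 0x...>'}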
...@@ -737,10 +743,11 @@ class ConfigurableTask(Task): ...@@ -737,10 +743,11 @@ class ConfigurableTask(Task):
def gold_alias(self, doc): def gold_alias(self, doc):
# TODO: reevaluate if we need this. implemented to have a # TODO: reevaluate if we need this. implemented to have a
# processed version of answer to put into gsm8k exact_match scoring as ref. # processed version of answer to put into gsm8k exact_match scoring as ref.
if self._config.gold_alias: if self._config.gold_alias is not None:
doc_to_target = self._config.gold_alias doc_to_target = self._config.gold_alias
else: else:
doc_to_target = self._config.doc_to_target # doc_to_target = self._config.doc_to_target
return self.doc_to_target(doc)
if type(doc_to_target) == str: if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc) return utils.apply_template(doc_to_target, doc)
...@@ -842,7 +849,11 @@ class ConfigurableTask(Task): ...@@ -842,7 +849,11 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
gold = int(self.doc_to_target(doc)) if self._config.gold_alias is not None:
gold = int(self.gold_alias(doc))
else:
gold = int(self.doc_to_target(doc))
pred = np.argmax(lls) pred = np.argmax(lls)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval( choices = ast.literal_eval(
...@@ -894,7 +905,9 @@ class ConfigurableTask(Task): ...@@ -894,7 +905,9 @@ class ConfigurableTask(Task):
for key, result in zip(self._metric_fn_list.keys(), results): for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute( _dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_kwargs[key] references=[gold],
predictions=[result],
**self._metric_fn_kwargs[key],
) )
result_dict = {**result_dict, **_dict} result_dict = {**result_dict, **_dict}
......
...@@ -183,9 +183,7 @@ def evaluate( ...@@ -183,9 +183,7 @@ def evaluate(
# get lists of each type of request # get lists of each type of request
for task_name, task in task_dict.items(): for task_name, task in task_dict.items():
versions[task_name] = task.VERSION versions[task_name] = task.VERSION
configs[task_name] = dict( configs[task_name] = dict(task.dump_config())
task.dump_config()
) # TODO: don't access a private attribute here ; for non-YAML tasks handle this case
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
# task_docs = list(task_doc_func()) # task_docs = list(task_doc_func())
......
...@@ -3,5 +3,6 @@ from . import openai_completions ...@@ -3,5 +3,6 @@ from . import openai_completions
from . import textsynth from . import textsynth
from . import dummy from . import dummy
from . import seq2seq from . import seq2seq
from . import hf_merged
# TODO: implement __all__ # TODO: implement __all__
...@@ -11,12 +11,13 @@ from lm_eval.logger import eval_logger ...@@ -11,12 +11,13 @@ from lm_eval.logger import eval_logger
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator from accelerate import Accelerator
from itertools import islice
@register_model("hf-causal") @register_model("hf-causal")
class HFLM(LM): class HFCausalLM(LM):
def __init__( def __init__(
self, self,
device="cuda", device="cuda",
...@@ -34,6 +35,7 @@ class HFLM(LM): ...@@ -34,6 +35,7 @@ class HFLM(LM):
assert isinstance(batch_size, int) assert isinstance(batch_size, int)
gpus = torch.cuda.device_count() gpus = torch.cuda.device_count()
if gpus <= 1: if gpus <= 1:
if device: if device:
if device not in ["cuda", "cpu"]: if device not in ["cuda", "cpu"]:
...@@ -84,6 +86,14 @@ class HFLM(LM): ...@@ -84,6 +86,14 @@ class HFLM(LM):
) )
self._rank = accelerator.local_process_index self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes self._world_size = accelerator.num_processes
# manually set model to use gpu, for case where many GPUs available but
# only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else: else:
self.model = accelerator.prepare(self.model) self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}") self._device = torch.device(f"cuda:{accelerator.local_process_index}")
...@@ -151,27 +161,33 @@ class HFLM(LM): ...@@ -151,27 +161,33 @@ class HFLM(LM):
logits returned from the model logits returned from the model
""" """
with torch.no_grad(): with torch.no_grad():
return self.model(inps)[0] return self.model(inps).logits
def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs): def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly # we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search. # for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys(): if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0]
)
if hasattr(self, "accelerator"): if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate( return self.accelerator.unwrap_model(self.model).generate(
context, context,
max_length=max_length, max_length=max_length,
pad_token_id=eos_token_id, stopping_criteria=stopping_criteria,
eos_token_id=eos_token_id, pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs, **generation_kwargs,
) )
else: else:
return self.model.generate( return self.model.generate(
context, context,
max_length=max_length, max_length=max_length,
pad_token_id=eos_token_id, stopping_criteria=stopping_criteria,
eos_token_id=eos_token_id, pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs, **generation_kwargs,
) )
...@@ -191,9 +207,6 @@ class HFLM(LM): ...@@ -191,9 +207,6 @@ class HFLM(LM):
return self._loglikelihood_tokens(new_reqs) return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
loglikelihoods = [] loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)): for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list( rolling_token_windows = list(
...@@ -362,6 +375,7 @@ class HFLM(LM): ...@@ -362,6 +375,7 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate) re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()): for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict): if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys(): if "until" in gen_kwargs.keys():
...@@ -383,12 +397,13 @@ class HFLM(LM): ...@@ -383,12 +397,13 @@ class HFLM(LM):
else: else:
max_gen_toks = self.max_gen_toks max_gen_toks = self.max_gen_toks
try: primary_until = until[0]
(primary_until,) = self.tok_encode(until[0]) # try:
except Exception: # (primary_until,) = self.tok_encode(until[0])
# if our primary until would be multiple tokens long, we'll have errors. # except Exception:
# TODO: handling this better will let us stop generating earlier + often. # # if our primary until would be multiple tokens long, we'll have errors.
primary_until = self.eot_token_id # # TODO: handling this better will let us stop generating earlier + often.
# primary_until = self.eot_token_id
context_enc = torch.tensor( context_enc = torch.tensor(
[self.tok_encode(context)[max_gen_toks - self.max_length :]] [self.tok_encode(context)[max_gen_toks - self.max_length :]]
...@@ -397,7 +412,7 @@ class HFLM(LM): ...@@ -397,7 +412,7 @@ class HFLM(LM):
cont = self._model_generate( cont = self._model_generate(
context=context_enc, context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks, max_length=context_enc.shape[1] + max_gen_toks,
eos_token_id=primary_until, stop=primary_until,
**gen_kwargs, **gen_kwargs,
) )
......
import torch
import transformers
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
import copy
from tqdm import tqdm
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
@register_model("hf-auto")
class HFLM(LM):
"""
An abstracted Huggingface model class. Enables usage with both models of
`transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
Supports data-parallel multi-GPU with HF Accelerate.
"""
AUTO_MODEL_CLASS = None
def __init__(
self,
device="cuda",
pretrained="gpt2",
revision="main",
low_cpu_mem_usage=None,
subfolder=None,
tokenizer=None,
batch_size=1,
):
super().__init__()
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'")
else:
eval_logger.info("Device not specified")
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
self._rank = 0
self._world_size = 1
else:
self._device = "cpu"
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
# get config
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
)
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
transformers.AutoModelForSeq2SeqLM,
]
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
).to(self.device)
# forever after, access self._model through self.model property
self.model.eval()
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
revision=revision,
)
self.vocab_size = self.tokenizer.vocab_size
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# multigpu support with accelerate
if gpus > 1:
accelerator = Accelerator()
if gpus > accelerator.num_processes:
eval_logger.warning(
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices."
)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
# manually set model to use gpu, for case where many GPUs available but
# only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else:
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
eval_logger.info(f"Using {gpus} devices with data parallelism")
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config
@property
def model(self):
# returns the model, unwrapping it if using Accelerate
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
@property
def max_length(self):
try:
return self.model.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
return self.model.config.max_position_embeddings
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
return self.batch_size_per_gpu
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str, left_truncate_len=None):
""" """
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
def tok_decode(self, tokens):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
return self.tokenizer.decode(tokens)
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
return self.tokenizer.decode(tokens, skip_special_tokens=True)
def _model_call(self, inps, attn_mask=None, labels=None):
"""
:param inps: torch.Tensor
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
[batch, sequence_ctx]. the size of sequence may vary from call to call
:param attn_mask: torch.Tensor, optional
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
(and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
:param labels: torch.Tensor, optional
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
(and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
:return
A torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model's decoder
"""
with torch.no_grad():
if attn_mask is not None or labels is not None:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
return self.model(
input_ids=inps, attention_mask=attn_mask, labels=labels
).logits
else:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0]
)
return self.model.generate(
context,
max_length=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
def _select_cont_toks(self, logits, contlen=None, inplen=None):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
assert (
contlen and not inplen
), "Selecting scored logits for Seq2SeqLM requires only cont. len"
# only discard right-padding.
# the logits input to this fn only contain decoder-side tokens.
logits = logits[:contlen]
return logits
def loglikelihood(self, requests):
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
# end of text as context
context_enc = [self.eot_token_id]
else:
context_enc = self.tok_encode(context)
continuation_enc = self.tok_encode(continuation)
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
pad_amnt = 0
if self.world_size > 1:
# We pad out the external document-level iterator so the inner iterator doesn't hang
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
gathered = (
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
)
pad_amnt = max(gathered) - gathered[self.rank]
if pad_amnt > 0:
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
if (self.world_size > 1) and (pad_amnt > 0):
string_nll = [x[0] for x in string_nll[:-pad_amnt]]
else:
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
res = []
def _collate(x):
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
# padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end
toks = x[1] + x[2]
return -len(toks), tuple(toks)
# TODO: automatic (variable) batch size detection for vectorization
re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
self.batch_size,
):
inps = []
cont_toks_list = []
inplens = []
conts = []
encoder_attns = []
padding_len_inp = None
padding_len_cont = None
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
assert len(continuation_enc) > 0
assert len(continuation_enc) <= self.max_length
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# model \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
# when too long to fit in context, truncate from the left
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
device=self.device,
)
(inplen,) = inp.shape
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
device=self.device,
)
(inplen,) = inp.shape
cont = torch.tensor(
(continuation_enc)[-self.max_length :],
# TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long,
device=self.device,
)
(contlen,) = cont.shape
conts.append(cont)
padding_len_cont = (
max(padding_len_cont, contlen)
if padding_len_cont is not None
else contlen
)
padding_len_inp = (
max(padding_len_inp, inplen)
if padding_len_inp is not None
else inplen
)
inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)
# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
batched_inps = utils.pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: left-pad encoder inps and mask?
batched_inps = utils.pad_and_concat(
padding_len_inp, inps
) # [batch, padding_len_inp]
batched_conts = utils.pad_and_concat(
padding_len_cont, conts
) # [batch, padding_len_cont]
batched_encoder_mask = utils.pad_and_concat(
padding_len_inp, encoder_attns
) # [batch, padding_len_inp]
call_kwargs = {
"attn_mask": batched_encoder_mask,
"labels": batched_conts,
}
multi_logits = F.log_softmax(
self._model_call(batched_inps, **call_kwargs), dim=-1
).cpu() # [batch, padding_length (inp or cont), vocab]
for (cache_key, _, _), logits, inplen, cont_toks in zip(
chunk, multi_logits, inplens, cont_toks_list
):
# Slice to original seq length
contlen = len(cont_toks)
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
ctx_len = (
inplen
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
logits = logits.unsqueeze(0) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
0
) # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-1
) # [1, seq]
# Answer: (log prob, is-exact-match)
answer = (float(logits.sum()), bool(max_equal))
res.append(answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
res = []
def _collate(x):
toks = self.tok_encode(x[0])
return len(toks), x[0]
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
print(gen_kwargs)
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# first stop sequence is used to halt generation upon encountering
(primary_until) = until[0]
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# max len for inputs = encoder's whole max_length
max_ctx_len = self.max_length
context_enc = torch.tensor(
[self.tok_encode(context, left_truncate_len=max_ctx_len)]
).to(self.device)
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont[0].tolist())
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
for term in until:
s = s.split(term)[0]
res.append(s)
return re_ord.get_original(res)
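A rough usage sketch for the new merged wrapper (hypothetical snippet, not part of this PR; the model names are examples only): the same class serves both architectures because AUTO_MODEL_CLASS is resolved from the checkpoint's HF config.

from lm_eval.models.hf_merged import HFLM

causal_lm = HFLM(pretrained="gpt2", device="cuda", batch_size=8)
seq2seq_lm = HFLM(pretrained="t5-small", device="cuda", batch_size=8)

# model_type is looked up in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, so the wrapper
# dispatches to AutoModelForCausalLM for gpt2 and AutoModelForSeq2SeqLM for t5-small
print(causal_lm.AUTO_MODEL_CLASS)   # <class 'transformers...AutoModelForCausalLM'>
print(seq2seq_lm.AUTO_MODEL_CLASS)  # <class 'transformers...AutoModelForSeq2SeqLM'>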
import torch import torch
import transformers import transformers
import copy
from tqdm import tqdm from tqdm import tqdm
import torch.nn.functional as F import torch.nn.functional as F
...@@ -10,13 +11,15 @@ from lm_eval.logger import eval_logger ...@@ -10,13 +11,15 @@ from lm_eval.logger import eval_logger
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator from accelerate import Accelerator
from typing import List
@register_model("hf-seq2seq", "seq2seq") @register_model("hf-seq2seq", "seq2seq")
class Seq2SeqHFLM(LM): class Seq2SeqHFLM(LM):
_DEFAULT_MAX_LENGTH: int = 2048 _DEFAULT_MAX_LENGTH: int = 2048
def __init__( def __init__(
self, self,
device="cuda", device="cuda",
...@@ -83,6 +86,14 @@ class Seq2SeqHFLM(LM): ...@@ -83,6 +86,14 @@ class Seq2SeqHFLM(LM):
print(warning) print(warning)
self._rank = accelerator.local_process_index self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes self._world_size = accelerator.num_processes
# manually set model to use gpu, for case where many GPUs available but
# only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else: else:
self.model = accelerator.prepare(self.model) self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}") self._device = torch.device(f"cuda:{accelerator.local_process_index}")
...@@ -101,7 +112,8 @@ class Seq2SeqHFLM(LM): ...@@ -101,7 +112,8 @@ class Seq2SeqHFLM(LM):
@property @property
def max_length(self): def max_length(self):
return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default? return self._DEFAULT_MAX_LENGTH # TODO: Is this a good default?
@property @property
def max_gen_toks(self): def max_gen_toks(self):
return 256 return 256
...@@ -121,14 +133,14 @@ class Seq2SeqHFLM(LM): ...@@ -121,14 +133,14 @@ class Seq2SeqHFLM(LM):
@property @property
def world_size(self): def world_size(self):
return self._world_size return self._world_size
def tok_encode(self, string: str): def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=True) return self.tokenizer.encode(string, add_special_tokens=True)
def tok_decode(self, tokens): def tok_decode(self, tokens):
return self.tokenizer.decode(tokens, skip_special_tokens=True) return self.tokenizer.decode(tokens, skip_special_tokens=True)
def _model_call(self, inps, attn_mask = None ,labels = None): def _model_call(self, inps, attn_mask=None, labels=None):
""" """
inps: a torch tensor of shape [batch, sequence_ctx] inps: a torch tensor of shape [batch, sequence_ctx]
the size of sequence may vary from call to call the size of sequence may vary from call to call
...@@ -140,22 +152,36 @@ class Seq2SeqHFLM(LM): ...@@ -140,22 +152,36 @@ class Seq2SeqHFLM(LM):
logits returned from the model logits returned from the model
""" """
with torch.no_grad(): with torch.no_grad():
return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits return self.model(
input_ids=inps, attention_mask=attn_mask, labels=labels
def _model_generate(self, context, max_length, stop): ).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria( stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0] self.tokenizer, stop, 1, context.shape[0]
) )
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate(
context,
max_new_tokens=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
**generation_kwargs,
)
else:
return self.model.generate(
context,
max_new_tokens=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
**generation_kwargs,
)
return self.model.generate(
context,
max_new_tokens=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
)
def loglikelihood(self, requests): def loglikelihood(self, requests):
new_reqs = [] new_reqs = []
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
...@@ -170,23 +196,23 @@ class Seq2SeqHFLM(LM): ...@@ -170,23 +196,23 @@ class Seq2SeqHFLM(LM):
new_reqs.append(((context, continuation), context_enc, continuation_enc)) new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs) return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests):
loglikelihoods = [] loglikelihoods = []
for (string,) in tqdm([req.args for req in requests]): for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list( rolling_token_windows = list(
map( map(
utils.make_disjoint_window, utils.make_disjoint_window,
utils.get_rolling_token_windows( utils.get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.eot_token_id,
max_seq_len=self.max_length, max_seq_len=self.max_length,
context_len=1, context_len=1,
), ),
) )
) )
#TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
rolling_token_windows = [(None,) + x for x in rolling_token_windows] rolling_token_windows = [(None,) + x for x in rolling_token_windows]
pad_amnt = 0 pad_amnt = 0
...@@ -215,7 +241,7 @@ class Seq2SeqHFLM(LM): ...@@ -215,7 +241,7 @@ class Seq2SeqHFLM(LM):
loglikelihoods.append(string_nll) loglikelihoods.append(string_nll)
return loglikelihoods return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False): def _loglikelihood_tokens(self, requests, disable_tqdm=False):
res = [] res = []
...@@ -229,7 +255,7 @@ class Seq2SeqHFLM(LM): ...@@ -229,7 +255,7 @@ class Seq2SeqHFLM(LM):
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate) re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks( for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))), tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
...@@ -239,7 +265,7 @@ class Seq2SeqHFLM(LM): ...@@ -239,7 +265,7 @@ class Seq2SeqHFLM(LM):
conts = [] conts = []
encoder_attns = [] encoder_attns = []
cont_toks_list = [] cont_toks_list = []
max_batch_length_inp = None max_batch_length_inp = None
max_batch_length_cont = None max_batch_length_cont = None
...@@ -261,33 +287,48 @@ class Seq2SeqHFLM(LM): ...@@ -261,33 +287,48 @@ class Seq2SeqHFLM(LM):
).to(self.device) ).to(self.device)
(contlen,) = cont.shape (contlen,) = cont.shape
max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen max_batch_length_inp = (
max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen max(max_batch_length_inp, inplen)
if max_batch_length_inp is not None
else inplen
)
max_batch_length_cont = (
max(max_batch_length_cont, contlen)
if max_batch_length_cont is not None
else contlen
)
inps.append(inp) # [1, inp_len] inps.append(inp) # [1, inp_len]
conts.append(cont) # [1, cont_len] conts.append(cont) # [1, cont_len]
encoder_attns.append(torch.ones_like(inp)) encoder_attns.append(torch.ones_like(inp))
cont_toks_list.append(continuation_enc) cont_toks_list.append(continuation_enc)
batched_inps = utils.pad_and_concat(max_batch_length_inp, inps) # [batch, padding_length] batched_inps = utils.pad_and_concat(
batched_conts = utils.pad_and_concat(max_batch_length_cont, conts) # [batch, padding_length] max_batch_length_inp, inps
batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns) ) # [batch, padding_length]
batched_conts = utils.pad_and_concat(
max_batch_length_cont, conts
) # [batch, padding_length]
batched_encoder_mask = utils.pad_and_concat(
max_batch_length_inp, encoder_attns
)
# need to make attention mask here too # need to make attention mask here too
multi_logits = F.log_softmax( multi_logits = F.log_softmax(
self._model_call(batched_inps, attn_mask = batched_encoder_mask, labels = batched_conts), dim=-1 self._model_call(
batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
),
dim=-1,
).cpu() # [batch, padding_length, vocab] ).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, cont_toks in zip( for (cache_key, _, _), logits, cont_toks in zip(
chunk, multi_logits, cont_toks_list chunk, multi_logits, cont_toks_list
): ):
# Slice to original seq length # Slice to original seq length
contlen = len(cont_toks) contlen = len(cont_toks)
logits = logits[: contlen].unsqueeze( logits = logits[:contlen].unsqueeze(0) # [1, seq, vocab]
0
) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation # Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1) greedy_tokens = logits.argmax(dim=-1)
...@@ -307,7 +348,7 @@ class Seq2SeqHFLM(LM): ...@@ -307,7 +348,7 @@ class Seq2SeqHFLM(LM):
res.append(answer) res.append(answer)
return re_ord.get_original(res) return re_ord.get_original(res)
def greedy_until(self, requests): def greedy_until(self, requests):
res = [] res = []
...@@ -317,9 +358,30 @@ class Seq2SeqHFLM(LM): ...@@ -317,9 +358,30 @@ class Seq2SeqHFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate) re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, until in tqdm(re_ord.get_reordered()): for context, gen_kwargs in tqdm(re_ord.get_reordered()):
if isinstance(until, str): until = None
until = [until] if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
print(gen_kwargs)
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [until]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
(primary_until) = until[0] (primary_until) = until[0]
context_enc = torch.tensor( context_enc = torch.tensor(
...@@ -327,62 +389,16 @@ class Seq2SeqHFLM(LM): ...@@ -327,62 +389,16 @@ class Seq2SeqHFLM(LM):
).to(self.device) ).to(self.device)
cont = self._model_generate( cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
) )
s = self.tok_decode(cont[0].tolist()) s = self.tok_decode(cont[0].tolist())
print(s)
for term in until: for term in until:
s = s.split(term)[0] s = s.split(term)[0]
print(s)
res.append(s) res.append(s)
return re_ord.get_original(res) return re_ord.get_original(res)
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
:, -self.sequence_id_len :
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
\ No newline at end of file
...@@ -23,7 +23,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for ...@@ -23,7 +23,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
- [ ] LogiQA - [ ] LogiQA
- [ ] HellaSwag - [ ] HellaSwag
- [ ] SWAG - [ ] SWAG
- [ ] OpenBookQA - [x] OpenBookQA
- [ ] SQuADv2 - [ ] SQuADv2
- [ ] RACE - [ ] RACE
- [ ] HeadQA - [ ] HeadQA
......
...@@ -22,7 +22,7 @@ def include_task_folder(task_dir): ...@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
Calling this function Calling this function
""" """
for root, subdirs, file_list in os.walk(task_dir): for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == []) and (len(file_list) > 0): if len(file_list) > 0:
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
yaml_path = os.path.join(root, f) yaml_path = os.path.join(root, f)
......
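The relaxed condition above means task YAMLs are now collected from every directory level under task_dir, not just leaf directories. A minimal stand-alone sketch of the same walk (the helper name is invented for illustration; the real function goes on to register each config):

import os

def find_task_yamls(task_dir: str):
    yaml_paths = []
    for root, _subdirs, file_list in os.walk(task_dir):
        # previously only leaf directories (subdirs == []) were considered
        for f in file_list:
            if f.endswith(".yaml"):
                yaml_paths.append(os.path.join(root, f))
    return yaml_paths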
group:
- multiple_choice
task: openbookqa
dataset_path: openbookqa
dataset_name: main
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey.lstrip()) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "{{question_stem}}"
doc_to_target: "{{gold}}" # this will be cast to an int.
should_decontaminate: true
doc_to_decontamination_query: "{{question_stem}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
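To see how template_aliases and doc_to_target compose for this new OpenBookQA config (TaskConfig.__post_init__ above concatenates the two strings before rendering), here is a hedged sketch that renders them with Jinja2 against a made-up document in the OpenBookQA schema; the harness' own template utilities do something very similar.

from jinja2 import Environment, StrictUndefined

env = Environment(undefined=StrictUndefined)

template_aliases = (
    "{% set answer_choices = choices['text'] %}"
    "{% set gold = choices.label.index(answerKey.lstrip()) %}"
)
doc_to_target = "{{gold}}"

doc = {
    "question_stem": "Which of these is a renewable resource?",
    "choices": {"text": ["coal", "wind", "oil", "sunlight"], "label": ["A", "B", "C", "D"]},
    "answerKey": "D",
}

# template_aliases is prepended to doc_to_target, so rendering yields the gold choice index
print(env.from_string(template_aliases + doc_to_target).render(**doc))  # "3", cast to int downstream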
...@@ -9,7 +9,8 @@ validation_split: validation ...@@ -9,7 +9,8 @@ validation_split: validation
test_split: test test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int. doc_to_target: " {{correct_answer}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -18,10 +18,12 @@ import torch ...@@ -18,10 +18,12 @@ import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice from itertools import islice
import torch
import transformers
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
class ExitCodeError(Exception): class ExitCodeError(Exception):
pass pass
...@@ -415,30 +417,102 @@ def create_iterator(raw_iterator, rank, world_size, limit=None): ...@@ -415,30 +417,102 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
""" """
return islice(raw_iterator, rank, limit, world_size) return islice(raw_iterator, rank, limit, world_size)
def pad_and_concat(max_length:int, tensors: List[torch.Tensor]):
""" def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
Method for padding a list of tensors given the maximum tensor """
length in the batch. Used for batching inputs and continuations in Method for padding a list of tensors given the maximum tensor
seq2seq models. length in the batch. Used for batching inputs and continuations in
seq2seq models.
""" """
assert (
padding_side == "left" or padding_side == "right"
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors): for i, tensor in enumerate(tensors):
tensor_len = tensor.shape[0] tensor_len = tensor.shape[0]
if tensor_len < max_length: if tensor_len < max_length:
tensors[i] = torch.cat( if padding_side == "right":
# right-pad
tensors[i] = torch.cat(
[ [
tensor, # [seq] tensor, # [seq]
torch.zeros(max_length - tensor_len, dtype=torch.long).to( torch.zeros(
tensor.device max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
],
dim=0,
).unsqueeze(0)
else:
# left-pad
tensors[i] = torch.cat(
[
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq] ), # [padding_length - seq]
tensor, # [seq]
], ],
dim=0, dim=0,
).unsqueeze(0) ).unsqueeze(0)
else: else:
tensors[i] = tensor.unsqueeze(0) tensors[i] = tensor.unsqueeze(0)
return torch.cat(tensors, dim = 0) return torch.cat(tensors, dim=0)
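A quick behavioral sketch of the new padding_side argument (tensors are illustrative; assumes the function is importable as lm_eval.utils.pad_and_concat):

import torch
from lm_eval.utils import pad_and_concat

a = torch.tensor([1, 2, 3], dtype=torch.long)
b = torch.tensor([4, 5], dtype=torch.long)

# right padding (default): zeros appended after the shorter sequence
print(pad_and_concat(3, [a, b]))                        # tensor([[1, 2, 3], [4, 5, 0]])
# left padding: zeros prepended, handy for batching decoder-only inputs
print(pad_and_concat(3, [a, b], padding_side="left"))   # tensor([[1, 2, 3], [0, 4, 5]])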
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
:, -self.sequence_id_len :
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
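For context, a hedged sketch of how these criteria plug into transformers' generate (the model choice is illustrative; note that _model_generate above passes 1 as initial_decoder_input_length, whereas this sketch skips the prompt tokens by passing the prompt length):

import transformers
from lm_eval.utils import stop_sequences_criteria

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

context = tokenizer("Question: 2 + 2 = ?\nAnswer:", return_tensors="pt").input_ids

# halt generation once any stop string shows up in the newly generated tokens
stopping_criteria = stop_sequences_criteria(
    tokenizer, ["\n\n", "Question:"], context.shape[1], context.shape[0]
)

out = model.generate(
    context,
    max_length=context.shape[1] + 32,
    stopping_criteria=stopping_criteria,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
)
print(tokenizer.decode(out[0][context.shape[1]:]))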
...@@ -13,6 +13,12 @@ setuptools.setup( ...@@ -13,6 +13,12 @@ setuptools.setup(
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
url="https://github.com/EleutherAI/lm-evaluation-harness", url="https://github.com/EleutherAI/lm-evaluation-harness",
packages=setuptools.find_packages(), packages=setuptools.find_packages(),
# required to include yaml files in pip installation
package_data={
"lm_eval": ["**/*.yaml"],
"examples": ["**/*.yaml"],
},
include_package_data=True,
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
......