"vscode:/vscode.git/clone" did not exist on "6cd94d981ea3d38841cc29474a70a7835c0f8dd6"
Unverified commit b1b4793c, authored by Hailey Schoelkopf and committed by GitHub

Merge pull request #3 from EleutherAI/seq2seq-support

Test + combine Causal and Seq2Seq multi-GPU
parents a6c640d3 9cf4a104
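
Before the file-by-file diff, a quick way to exercise both code paths this PR merges could look like the sketch below. It is not part of the commit: it assumes the harness's `simple_evaluate` entry point and the `pretrained=...` model_args convention; only the model names `hf-causal` and `hf-seq2seq` are taken from the diff itself.

# Hypothetical smoke test (not part of this PR); assumes evaluator.simple_evaluate
# accepts a registered model name plus a model_args string.
from lm_eval import evaluator

for model_name, model_args in [
    ("hf-causal", "pretrained=gpt2"),                   # decoder-only path
    ("hf-seq2seq", "pretrained=google/flan-t5-small"),  # encoder-decoder path
]:
    results = evaluator.simple_evaluate(
        model=model_name,
        model_args=model_args,
        tasks=["sciq"],  # sciq is one of the task configs touched by this PR
        limit=8,         # tiny slice, just to hit both request code paths
    )
    print(model_name, results["results"])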
@@ -73,7 +73,7 @@ class TaskConfig(dict):
     repeats: int = 1
     metric_list: str = None
-    gold_alias: str = None
+    gold_alias: Union[Callable, str] = None
     output_type: str = "greedy_until"
     generation_kwargs: dict = None
     delimiter: str = "\n\n"
@@ -95,15 +95,18 @@ class TaskConfig(dict):
             self.doc_to_target = self.template_aliases + self.doc_to_target
         if type(self.gold_alias) == str:
-            self.gold_alias = self.template_aliases + self.doc_to_target
+            self.gold_alias = self.template_aliases + self.gold_alias

-        if self.generation_kwargs or self.output_type == "greedy_until":
+        if self.generation_kwargs:
             assert (
                 self.output_type == "greedy_until"
             ), "passed `generation_kwargs`, but not using a generation request type!"
+        elif self.output_type == "greedy_until":
             # ensure that we greedily generate in absence of explicit arguments otherwise
             self.generation_kwargs = {"do_sample": False, "temperature": 0.0}

+    # TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
     def __getitem__(self, item):
         return getattr(self, item)
@@ -122,6 +125,9 @@ class TaskConfig(dict):
         for k, v in list(cfg_dict.items()):
             if v is None:
                 cfg_dict.pop(k)
+            elif isinstance(v, Callable):
+                # TODO: this should handle Promptsource template objects as a separate case?
+                cfg_dict[k] = str(v)
         return cfg_dict
@@ -737,10 +743,11 @@ class ConfigurableTask(Task):
     def gold_alias(self, doc):
         # TODO: reevaluate if we need this. implemented to have a
         # processed version of answer to put into gsm8k exact_match scoring as ref.
-        if self._config.gold_alias:
+        if self._config.gold_alias is not None:
             doc_to_target = self._config.gold_alias
         else:
-            doc_to_target = self._config.doc_to_target
+            # doc_to_target = self._config.doc_to_target
+            return self.doc_to_target(doc)

         if type(doc_to_target) == str:
             return utils.apply_template(doc_to_target, doc)
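
Together with the `Union[Callable, str]` annotation in the first hunk, `gold_alias` may now be either a Jinja template string or a callable. A hypothetical callable for the gsm8k use case mentioned in the comment above might look like the following; the `answer` field and `####` delimiter follow the public gsm8k format, but the helper itself is illustrative and not part of this PR.

# Illustrative gold_alias callable (not in this PR): extract the final numeric
# answer from a gsm8k-style target such as "... #### 42" so that exact_match
# can compare generations against a clean reference string.
def gsm8k_gold_alias(doc: dict) -> str:
    answer = doc["answer"]  # gsm8k stores the worked solution plus "#### <answer>"
    return answer.split("####")[-1].strip()

# A task YAML could then point at such a helper via the `!function` constructor
# mentioned in the earlier TODO, e.g. `gold_alias: !function utils.gsm8k_gold_alias`.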
@@ -842,7 +849,11 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "multiple_choice":
             lls, is_greedy = zip(*results)
-            gold = int(self.doc_to_target(doc))
+            if self._config.gold_alias is not None:
+                gold = int(self.gold_alias(doc))
+            else:
+                gold = int(self.doc_to_target(doc))
             pred = np.argmax(lls)
             # retrieve choices in List[str] form, to compute choice lengths, etc.
             choices = ast.literal_eval(
@@ -894,7 +905,9 @@ class ConfigurableTask(Task):
             for key, result in zip(self._metric_fn_list.keys(), results):
                 _dict = self._metric_fn_list[key].compute(
-                    references=[gold], predictions=[result], **self._metric_kwargs[key]
+                    references=[gold],
+                    predictions=[result],
+                    **self._metric_fn_kwargs[key],
                 )
                 result_dict = {**result_dict, **_dict}
...
@@ -183,9 +183,7 @@ def evaluate(
     # get lists of each type of request
     for task_name, task in task_dict.items():
         versions[task_name] = task.VERSION
-        configs[task_name] = dict(
-            task.dump_config()
-        )  # TODO: don't access a private attribute here ; for non-YAML tasks handle this case
+        configs[task_name] = dict(task.dump_config())

         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
         # task_docs = list(task_doc_func())
...
@@ -3,5 +3,6 @@ from . import openai_completions
 from . import textsynth
 from . import dummy
 from . import seq2seq
+from . import hf_merged

 # TODO: implement __all__
@@ -11,12 +11,13 @@ from lm_eval.logger import eval_logger
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

 from accelerate import Accelerator
+from itertools import islice


 @register_model("hf-causal")
-class HFLM(LM):
+class HFCausalLM(LM):
     def __init__(
         self,
         device="cuda",
@@ -34,6 +35,7 @@ class HFLM(LM):
         assert isinstance(batch_size, int)

         gpus = torch.cuda.device_count()
         if gpus <= 1:
             if device:
                 if device not in ["cuda", "cpu"]:
@@ -84,6 +86,14 @@ class HFLM(LM):
                 )
                 self._rank = accelerator.local_process_index
                 self._world_size = accelerator.num_processes
+                # manually set model to use gpu, for case where many GPUs available but
+                # only seek to use one
+                self._device = (
+                    torch.device(f"cuda:{accelerator.local_process_index}")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+                self.model.to(self.device)
             else:
                 self.model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
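
For orientation, the multi-GPU handling here is plain data parallelism: each process either pins the model to its local device (the new branch above) or wraps it with accelerator.prepare, and then consumes every world_size-th request, which is what utils.create_iterator does with islice later in this diff. A stripped-down, standalone sketch of that sharding idea (not the harness's actual classes):

# Standalone sketch of rank/world-size request sharding; the real code takes
# rank and world_size from accelerate's Accelerator object.
from itertools import islice

def shard(requests, rank: int, world_size: int):
    # each process handles every world_size-th request, starting at its own rank
    return islice(requests, rank, None, world_size)

requests = [f"req-{i}" for i in range(6)]
print(list(shard(requests, rank=0, world_size=2)))  # ['req-0', 'req-2', 'req-4']
print(list(shard(requests, rank=1, world_size=2)))  # ['req-1', 'req-3', 'req-5']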
@@ -151,27 +161,33 @@ class HFLM(LM):
             logits returned from the model
         """
         with torch.no_grad():
-            return self.model(inps)[0]
+            return self.model(inps).logits

-    def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
         # for non-greedy gen. This should be reevaluated when considering beam search.
         if "do_sample" not in generation_kwargs.keys():
             generation_kwargs["do_sample"] = False
+        # build stopping criteria
+        stopping_criteria = stop_sequences_criteria(
+            self.tokenizer, stop, 1, context.shape[0]
+        )
         if hasattr(self, "accelerator"):
             return self.accelerator.unwrap_model(self.model).generate(
                 context,
                 max_length=max_length,
-                pad_token_id=eos_token_id,
-                eos_token_id=eos_token_id,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                use_cache=True,
                 **generation_kwargs,
             )
         else:
             return self.model.generate(
                 context,
                 max_length=max_length,
-                pad_token_id=eos_token_id,
-                eos_token_id=eos_token_id,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                use_cache=True,
                 **generation_kwargs,
             )
@@ -191,9 +207,6 @@ class HFLM(LM):
         return self._loglikelihood_tokens(new_reqs)

     def loglikelihood_rolling(self, requests):
-        # TODO: Implement caching once we've confirmed the perplexity implementation
-        # TODO: automatic batch size detection for vectorization
-
         loglikelihoods = []
         for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
@@ -362,6 +375,7 @@ class HFLM(LM):
         re_ord = utils.Reorderer([req.args for req in requests], _collate)

         for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+            until = None
             if isinstance(gen_kwargs, dict):
                 gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                 if "until" in gen_kwargs.keys():
@@ -383,12 +397,13 @@ class HFLM(LM):
             else:
                 max_gen_toks = self.max_gen_toks
-            try:
-                (primary_until,) = self.tok_encode(until[0])
-            except Exception:
-                # if our primary until would be multiple tokens long, we'll have errors.
-                # TODO: handling this better will let us stop generating earlier + often.
-                primary_until = self.eot_token_id
+            primary_until = until[0]
+            # try:
+            #     (primary_until,) = self.tok_encode(until[0])
+            # except Exception:
+            #     # if our primary until would be multiple tokens long, we'll have errors.
+            #     # TODO: handling this better will let us stop generating earlier + often.
+            #     primary_until = self.eot_token_id

             context_enc = torch.tensor(
                 [self.tok_encode(context)[max_gen_toks - self.max_length :]]
@@ -397,7 +412,7 @@ class HFLM(LM):
             cont = self._model_generate(
                 context=context_enc,
                 max_length=context_enc.shape[1] + max_gen_toks,
-                eos_token_id=primary_until,
+                stop=primary_until,
                 **gen_kwargs,
             )
...
This diff is collapsed.
 import torch
 import transformers
+import copy
 from tqdm import tqdm

 import torch.nn.functional as F
@@ -10,13 +11,15 @@ from lm_eval.logger import eval_logger
 from lm_eval.api.registry import register_model
 from lm_eval.api.model import LM
+from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria

 from accelerate import Accelerator
+from typing import List


 @register_model("hf-seq2seq", "seq2seq")
 class Seq2SeqHFLM(LM):
     _DEFAULT_MAX_LENGTH: int = 2048

     def __init__(
         self,
         device="cuda",
@@ -83,6 +86,14 @@ class Seq2SeqHFLM(LM):
                 print(warning)
                 self._rank = accelerator.local_process_index
                 self._world_size = accelerator.num_processes
+                # manually set model to use gpu, for case where many GPUs available but
+                # only seek to use one
+                self._device = (
+                    torch.device(f"cuda:{accelerator.local_process_index}")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+                self.model.to(self.device)
             else:
                 self.model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
@@ -101,7 +112,8 @@ class Seq2SeqHFLM(LM):
     @property
     def max_length(self):
-        return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
+        return self._DEFAULT_MAX_LENGTH  # TODO: Is this a good default?
+
     @property
     def max_gen_toks(self):
         return 256
@@ -121,14 +133,14 @@ class Seq2SeqHFLM(LM):
     @property
     def world_size(self):
         return self._world_size

     def tok_encode(self, string: str):
         return self.tokenizer.encode(string, add_special_tokens=True)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask = None ,labels = None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
@@ -140,22 +152,36 @@ class Seq2SeqHFLM(LM):
             logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
+            return self.model(
+                input_ids=inps, attention_mask=attn_mask, labels=labels
+            ).logits

-    def _model_generate(self, context, max_length, stop):
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+        # we require users to pass do_sample=True explicitly
+        # for non-greedy gen. This should be reevaluated when considering beam search.
+        if "do_sample" not in generation_kwargs.keys():
+            generation_kwargs["do_sample"] = False
+
+        # build stopping criteria
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, 1, context.shape[0]
         )
-        return self.model.generate(
-            context,
-            max_new_tokens=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self.model).generate(
+                context,
+                max_new_tokens=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                **generation_kwargs,
+            )
+        else:
+            return self.model.generate(
+                context,
+                max_new_tokens=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                **generation_kwargs,
+            )

     def loglikelihood(self, requests):
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
@@ -170,23 +196,23 @@ class Seq2SeqHFLM(LM):
             new_reqs.append(((context, continuation), context_enc, continuation_enc))

         return self._loglikelihood_tokens(new_reqs)

     def loglikelihood_rolling(self, requests):
         loglikelihoods = []
-        for (string,) in tqdm([req.args for req in requests]):
+        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
                         prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
                 )
             )
-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0
@@ -215,7 +241,7 @@ class Seq2SeqHFLM(LM):
                 loglikelihoods.append(string_nll)

         return loglikelihoods

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         res = []
@@ -229,7 +255,7 @@ class Seq2SeqHFLM(LM):
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)

         re_ord = utils.Reorderer(requests, _collate)

         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
@@ -239,7 +265,7 @@ class Seq2SeqHFLM(LM):
             conts = []
             encoder_attns = []
             cont_toks_list = []

             max_batch_length_inp = None
             max_batch_length_cont = None
@@ -261,33 +287,48 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape

-                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
-                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
+                max_batch_length_inp = (
+                    max(max_batch_length_inp, inplen)
+                    if max_batch_length_inp is not None
+                    else inplen
+                )
+                max_batch_length_cont = (
+                    max(max_batch_length_cont, contlen)
+                    if max_batch_length_cont is not None
+                    else contlen
+                )

                 inps.append(inp)  # [1, inp_len]
                 conts.append(cont)  # [1, cont_len]
                 encoder_attns.append(torch.ones_like(inp))
                 cont_toks_list.append(continuation_enc)

-            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
-            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
-            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
+            batched_inps = utils.pad_and_concat(
+                max_batch_length_inp, inps
+            )  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(
+                max_batch_length_cont, conts
+            )  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(
+                max_batch_length_inp, encoder_attns
+            )

             # need to make attention mask here too
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, attn_mask = batched_encoder_mask, labels = batched_conts), dim=-1
+                self._model_call(
+                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
+                ),
+                dim=-1,
             ).cpu()  # [batch, padding_length, vocab]

             for (cache_key, _, _), logits, cont_toks in zip(
                 chunk, multi_logits, cont_toks_list
             ):
                 # Slice to original seq length
                 contlen = len(cont_toks)
-                logits = logits[: contlen].unsqueeze(
-                    0
-                )  # [1, seq, vocab]
+                logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
@@ -307,7 +348,7 @@ class Seq2SeqHFLM(LM):
                 res.append(answer)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
         res = []
@@ -317,9 +358,30 @@ class Seq2SeqHFLM(LM):
         re_ord = utils.Reorderer([req.args for req in requests], _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
+        for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+            until = None
+            if isinstance(gen_kwargs, dict):
+                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                print(gen_kwargs)
+                if "until" in gen_kwargs.keys():
+                    until = gen_kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [gen_kwargs]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
+                )
+            if not until:
+                until = [self.tok_decode(self.eot_token_id)]
+            if "max_gen_toks" in gen_kwargs.keys():
+                max_gen_toks = gen_kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
             (primary_until) = until[0]

             context_enc = torch.tensor(
@@ -327,62 +389,16 @@ class Seq2SeqHFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
+                context=context_enc,
+                max_length=context_enc.shape[1] + max_gen_toks,
+                stop=primary_until,
+                **gen_kwargs,
             )

             s = self.tok_decode(cont[0].tolist())
+            print(s)

             for term in until:
                 s = s.split(term)[0]
+            print(s)

             res.append(s)

         return re_ord.get_original(res)
-
-
-class MultiTokenEOSCriteria(transformers.StoppingCriteria):
-    """Criteria to stop on the specified multi-token sequence."""
-
-    def __init__(
-        self,
-        sequence: str,
-        tokenizer: transformers.PreTrainedTokenizer,
-        initial_decoder_input_length: int,
-        batch_size: int,
-    ):
-        self.initial_decoder_input_length = initial_decoder_input_length
-        self.done_tracker = [False] * batch_size
-        self.sequence = sequence
-        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
-        self.sequence_id_len = len(self.sequence_ids)
-        self.tokenizer = tokenizer
-
-    def __call__(self, input_ids, scores, **kwargs) -> bool:
-        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
-        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
-            :, -self.sequence_id_len :
-        ]
-        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
-        for i, done in enumerate(self.done_tracker):
-            if not done:
-                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
-        return False not in self.done_tracker
-
-
-def stop_sequences_criteria(
-    tokenizer: transformers.PreTrainedTokenizer,
-    stop_sequences: List[str],
-    initial_decoder_input_length: int,
-    batch_size: int,
-) -> transformers.StoppingCriteriaList:
-    return transformers.StoppingCriteriaList(
-        [
-            *[
-                MultiTokenEOSCriteria(
-                    sequence, tokenizer, initial_decoder_input_length, batch_size
-                )
-                for sequence in stop_sequences
-            ],
-        ]
-    )
\ No newline at end of file
@@ -23,7 +23,7 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] LogiQA
 - [ ] HellaSwag
 - [ ] SWAG
-- [ ] OpenBookQA
+- [x] OpenBookQA
 - [ ] SQuADv2
 - [ ] RACE
 - [ ] HeadQA
...
@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
     Calling this function
     """
     for root, subdirs, file_list in os.walk(task_dir):
-        if (subdirs == []) and (len(file_list) > 0):
+        if len(file_list) > 0:
             for f in file_list:
                 if f.endswith(".yaml"):
                     yaml_path = os.path.join(root, f)
...
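
The condition change above means task YAMLs are now discovered even in directories that also contain subdirectories; previously only leaf directories were scanned. A small standalone illustration of the difference, assuming a hypothetical tasks layout:

# Standalone illustration (hypothetical directory layout, not from this PR):
# under the old `subdirs == []` guard, a directory holding both YAML files and
# a subdirectory is skipped; the new `len(file_list) > 0` guard visits it.
import os

for root, subdirs, file_list in os.walk("lm_eval/tasks"):
    old_rule = (subdirs == []) and (len(file_list) > 0)  # leaf directories only
    new_rule = len(file_list) > 0                        # any directory with files
    yamls = [f for f in file_list if f.endswith(".yaml")]
    if yamls:
        print(root, "old:", old_rule, "new:", new_rule, "yamls:", yamls)

The diff continues with the new OpenBookQA task config added by this PR.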
group:
- multiple_choice
task: openbookqa
dataset_path: openbookqa
dataset_name: main
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey.lstrip()) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "{{question_stem}}"
doc_to_target: "{{gold}}" # this will be cast to an int.
should_decontaminate: true
doc_to_decontamination_query: "{{question_stem}}"
metric_list:
- metric: acc
aggregation: mean
higher_is_better: true
- metric: acc_norm
aggregation: mean
higher_is_better: true
@@ -9,7 +9,8 @@ validation_split: validation
 test_split: test
 template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
 doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
-doc_to_target: "{{gold}}" # this will be cast to an int.
+doc_to_target: " {{correct_answer}}"
+gold_alias: "{{gold}}" # this will be cast to an int.
 metric_list:
   - metric: acc
     aggregation: mean
...
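
To make the split above concrete: doc_to_target now renders the answer text itself, while gold_alias renders the integer choice index that accuracy metrics consume. A standalone Jinja2 sketch against a fabricated SciQ-style record (illustrative only, not the harness's own template machinery):

# Standalone Jinja2 sketch with a made-up SciQ-style document, showing what the
# sciq templates above produce for each field.
from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(loader=BaseLoader(), undefined=StrictUndefined)

doc = {
    "question": "What process do plants use to make food from sunlight?",
    "correct_answer": "photosynthesis",
    "distractor1": "respiration",
    "distractor2": "fermentation",
    "distractor3": "transpiration",
}

template_aliases = (
    "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}"
    "{% set gold = 3 %}"
)

print(env.from_string(template_aliases + " {{correct_answer}}").render(**doc))  # " photosynthesis"
print(env.from_string(template_aliases + "{{gold}}").render(**doc))             # "3", later cast to an int index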
@@ -18,10 +18,12 @@ import torch
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
+import torch
+import transformers

 from lm_eval.logger import eval_logger


 class ExitCodeError(Exception):
     pass
@@ -415,30 +417,102 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     """
     return islice(raw_iterator, rank, limit, world_size)


-def pad_and_concat(max_length:int, tensors: List[torch.Tensor]):
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
     """
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+
     for i, tensor in enumerate(tensors):
         tensor_len = tensor.shape[0]
         if tensor_len < max_length:
-            tensors[i] = torch.cat(
-                [
-                    tensor,  # [seq]
-                    torch.zeros(max_length - tensor_len, dtype=torch.long).to(
-                        tensor.device
-                    ),  # [padding_length - seq]
-                ],
-                dim=0,
-            ).unsqueeze(0)
+            if padding_side == "right":
+                # right-pad
+                tensors[i] = torch.cat(
+                    [
+                        tensor,  # [seq]
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                            device=tensor.device,
+                        ),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+            else:
+                # left-pad
+                tensors[i] = torch.cat(
+                    [
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                            device=tensor.device,
+                        ),  # [padding_length - seq]
+                        tensor,  # [seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
         else:
             tensors[i] = tensor.unsqueeze(0)
-    return torch.cat(tensors, dim = 0)
+
+    return torch.cat(tensors, dim=0)


 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
+
+
+# Multi-token stopping criteria
+class MultiTokenEOSCriteria(transformers.StoppingCriteria):
+    """Criteria to stop on the specified multi-token sequence."""
+
+    def __init__(
+        self,
+        sequence: str,
+        tokenizer: transformers.PreTrainedTokenizer,
+        initial_decoder_input_length: int,
+        batch_size: int,
+    ):
+        self.initial_decoder_input_length = initial_decoder_input_length
+        self.done_tracker = [False] * batch_size
+        self.sequence = sequence
+        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        self.sequence_id_len = len(self.sequence_ids)
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
+            :, -self.sequence_id_len :
+        ]
+        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
+        for i, done in enumerate(self.done_tracker):
+            if not done:
+                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
+        return False not in self.done_tracker
+
+
+def stop_sequences_criteria(
+    tokenizer: transformers.PreTrainedTokenizer,
+    stop_sequences: List[str],
+    initial_decoder_input_length: int,
+    batch_size: int,
+) -> transformers.StoppingCriteriaList:
+    return transformers.StoppingCriteriaList(
+        [
+            *[
+                MultiTokenEOSCriteria(
+                    sequence, tokenizer, initial_decoder_input_length, batch_size
+                )
+                for sequence in stop_sequences
+            ],
+        ]
+    )
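
As a usage note (not part of the diff): the stopping-criteria helpers above plug straight into transformers' generate. A minimal sketch, assuming the small `gpt2` checkpoint is available:

# Minimal usage sketch: stop greedy generation once the multi-token string "\n\n"
# has been produced after the prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.utils import stop_sequences_criteria  # added by this PR

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Q: What is the capital of France?\nA:"
context = torch.tensor([tokenizer.encode(prompt)])

stopping_criteria = stop_sequences_criteria(
    tokenizer,
    ["\n\n"],
    initial_decoder_input_length=context.shape[1],  # only check newly generated tokens
    batch_size=1,
)
out = model.generate(
    context,
    max_length=context.shape[1] + 32,
    stopping_criteria=stopping_criteria,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
)
print(tokenizer.decode(out[0][context.shape[1]:]))

The harness itself passes initial_decoder_input_length=1 (see _model_generate above), which also works because only the trailing tokens are compared against each stop string.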
@@ -13,6 +13,12 @@ setuptools.setup(
     long_description_content_type="text/markdown",
     url="https://github.com/EleutherAI/lm-evaluation-harness",
     packages=setuptools.find_packages(),
+    # required to include yaml files in pip installation
+    package_data={
+        "lm_eval": ["**/*.yaml"],
+        "examples": ["**/*.yaml"],
+    },
+    include_package_data=True,
     classifiers=[
         "Development Status :: 3 - Alpha",
         "Programming Language :: Python :: 3",
...
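
A quick post-install check for the packaging change above (assumes lm_eval was pip-installed from a build containing this commit):

# List the task YAMLs that actually shipped inside the installed package.
import os
import lm_eval

pkg_root = os.path.dirname(lm_eval.__file__)
yamls = [
    os.path.join(root, f)
    for root, _, files in os.walk(pkg_root)
    for f in files
    if f.endswith(".yaml")
]
print(f"{len(yamls)} YAML task configs found under {pkg_root}")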