Commit e779b52c authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

make seq2seq take correct args format

parent 53ff3671
import torch
import transformers
import copy
from tqdm import tqdm
import torch.nn.functional as F
......@@ -10,8 +11,9 @@ from lm_eval.logger import eval_logger
from lm_eval.api.registry import register_model
from lm_eval.api.model import LM
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from typing import List
@register_model("hf-seq2seq", "seq2seq")
......@@ -83,6 +85,14 @@ class Seq2SeqHFLM(LM):
print(warning)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
# manually set model to use gpu, for case where many GPUs available but
# only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else:
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
......@@ -142,18 +152,30 @@ class Seq2SeqHFLM(LM):
with torch.no_grad():
return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
def _model_generate(self, context, max_length, stop):
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0]
)
return self.model.generate(
context,
max_new_tokens=max_length,
stopping_criteria=stopping_criteria,
do_sample=False,
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate(
context,
max_new_tokens=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
**generation_kwargs,
)
else:
return self.model.generate(
context,
max_new_tokens=max_length,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
**generation_kwargs,
)
def loglikelihood(self, requests):
......@@ -173,7 +195,7 @@ class Seq2SeqHFLM(LM):
def loglikelihood_rolling(self, requests):
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests]):
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
......@@ -317,9 +339,30 @@ class Seq2SeqHFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, until in tqdm(re_ord.get_reordered()):
if isinstance(until, str):
until = [until]
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
print(gen_kwargs)
if "until" in gen_kwargs.keys():
until = gen_kwargs.pop("until")
if isinstance(until, str):
until = [gen_kwargs]
elif not isinstance(until, list):
raise ValueError(
f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
if "max_gen_toks" in gen_kwargs.keys():
max_gen_toks = gen_kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
(primary_until) = until[0]
context_enc = torch.tensor(
......@@ -327,62 +370,17 @@ class Seq2SeqHFLM(LM):
).to(self.device)
cont = self._model_generate(
context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
s = self.tok_decode(cont[0].tolist())
print(s)
for term in until:
s = s.split(term)[0]
print(s)
res.append(s)
return re_ord.get_original(res)
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
initial_decoder_input_length: int,
batch_size: int,
):
self.initial_decoder_input_length = initial_decoder_input_length
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
# For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
:, -self.sequence_id_len :
]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if not done:
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker
def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
stop_sequences: List[str],
initial_decoder_input_length: int,
batch_size: int,
) -> transformers.StoppingCriteriaList:
return transformers.StoppingCriteriaList(
[
*[
MultiTokenEOSCriteria(
sequence, tokenizer, initial_decoder_input_length, batch_size
)
for sequence in stop_sequences
],
]
)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment