"vscode:/vscode.git/clone" did not exist on "b3f57016200141f00926b66dd605f2f6dad2f7dc"
Commit 0b4f88dd authored by haileyschoelkopf, committed by lintangsutawika

make seq2seq take correct args format

parent 1a6b31a8
 import torch
 import transformers
+import copy
 from tqdm import tqdm
 import torch.nn.functional as F
@@ -10,8 +11,9 @@ from lm_eval.logger import eval_logger
 from lm_eval.api.registry import register_model
 from lm_eval.api.model import LM
+from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator
-from typing import List


 @register_model("hf-seq2seq", "seq2seq")
@@ -83,6 +85,14 @@ class Seq2SeqHFLM(LM):
                     print(warning)
                 self._rank = accelerator.local_process_index
                 self._world_size = accelerator.num_processes
+                # manually set model to use gpu, for case where many GPUs available but
+                # only seek to use one
+                self._device = (
+                    torch.device(f"cuda:{accelerator.local_process_index}")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+                self.model.to(self.device)
             else:
                 self.model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
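For context, a rough standalone sketch (not part of the commit) of the placement logic this hunk adds: when only a single process is running but several GPUs are visible, the model is pinned to the device matching the local process index instead of relying on `accelerator.prepare`. The surrounding branch condition is an assumption, since the diff only shows the body of each branch.

```python
# Hedged sketch of the single-process device-placement logic; the branching
# condition (num_processes == 1) is assumed, not shown in the diff.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
if accelerator.num_processes == 1:
    # many GPUs may be visible, but we only want the one matching this process
    device = (
        torch.device(f"cuda:{accelerator.local_process_index}")
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    # model.to(device) would then move the unwrapped model explicitly
else:
    # multi-process case: accelerator.prepare(model) handles wrapping/placement
    device = torch.device(f"cuda:{accelerator.local_process_index}")
```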
@@ -142,18 +152,30 @@ class Seq2SeqHFLM(LM):
         with torch.no_grad():
             return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits

-    def _model_generate(self, context, max_length, stop):
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+        # we require users to pass do_sample=True explicitly
+        # for non-greedy gen. This should be reevaluated when considering beam search.
+        if "do_sample" not in generation_kwargs.keys():
+            generation_kwargs["do_sample"] = False
+        # build stopping criteria
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, 1, context.shape[0]
         )
-        return self.model.generate(
-            context,
-            max_new_tokens=max_length,
-            stopping_criteria=stopping_criteria,
-            do_sample=False,
-        )
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self.model).generate(
+                context,
+                max_new_tokens=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                **generation_kwargs,
+            )
+        else:
+            return self.model.generate(
+                context,
+                max_new_tokens=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                **generation_kwargs,
+            )

     def loglikelihood(self, requests):
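To make the new `_model_generate` signature concrete, here is a minimal, hedged usage sketch outside the class. It assumes `stop_sequences_criteria` is importable from `lm_eval.utils`, as the new import at the top of the diff suggests; the model name and stop strings are placeholders.

```python
import torch
import transformers
from lm_eval.utils import stop_sequences_criteria  # relocated helper, per the new import

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")

context = tokenizer("translate English to German: Hello.", return_tensors="pt").input_ids
generation_kwargs = {}
# mirror the default above: greedy decoding unless do_sample=True is passed explicitly
generation_kwargs.setdefault("do_sample", False)

# one stopping criterion per stop string, tracked per batch element
stopping_criteria = stop_sequences_criteria(tokenizer, ["\n\n"], 1, context.shape[0])

with torch.no_grad():
    out = model.generate(
        context,
        max_new_tokens=32,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.eos_token_id,
        **generation_kwargs,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```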
@@ -173,7 +195,7 @@ class Seq2SeqHFLM(LM):
     def loglikelihood_rolling(self, requests):
         loglikelihoods = []

-        for (string,) in tqdm([req.args for req in requests]):
+        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
                 map(
                     utils.make_disjoint_window,
@@ -317,9 +339,30 @@ class Seq2SeqHFLM(LM):
         re_ord = utils.Reorderer([req.args for req in requests], _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
+        for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+            until = None
+            if isinstance(gen_kwargs, dict):
+                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                print(gen_kwargs)
+                if "until" in gen_kwargs.keys():
+                    until = gen_kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [until]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
+                )
+            if not until:
+                until = [self.tok_decode(self.eot_token_id)]
+            if "max_gen_toks" in gen_kwargs.keys():
+                max_gen_toks = gen_kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
             (primary_until) = until[0]

             context_enc = torch.tensor(
@@ -327,62 +370,17 @@ class Seq2SeqHFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
+                context=context_enc,
+                max_length=context_enc.shape[1] + max_gen_toks,
+                stop=primary_until,
+                **gen_kwargs,
             )

             s = self.tok_decode(cont[0].tolist())
+            print(s)

             for term in until:
                 s = s.split(term)[0]
+            print(s)

             res.append(s)

         return re_ord.get_original(res)

-class MultiTokenEOSCriteria(transformers.StoppingCriteria):
-    """Criteria to stop on the specified multi-token sequence."""
-
-    def __init__(
-        self,
-        sequence: str,
-        tokenizer: transformers.PreTrainedTokenizer,
-        initial_decoder_input_length: int,
-        batch_size: int,
-    ):
-        self.initial_decoder_input_length = initial_decoder_input_length
-        self.done_tracker = [False] * batch_size
-        self.sequence = sequence
-        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
-        self.sequence_id_len = len(self.sequence_ids)
-        self.tokenizer = tokenizer
-
-    def __call__(self, input_ids, scores, **kwargs) -> bool:
-        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
-        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
-            :, -self.sequence_id_len :
-        ]
-        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
-        for i, done in enumerate(self.done_tracker):
-            if not done:
-                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
-        return False not in self.done_tracker
-
-
-def stop_sequences_criteria(
-    tokenizer: transformers.PreTrainedTokenizer,
-    stop_sequences: List[str],
-    initial_decoder_input_length: int,
-    batch_size: int,
-) -> transformers.StoppingCriteriaList:
-    return transformers.StoppingCriteriaList(
-        [
-            *[
-                MultiTokenEOSCriteria(
-                    sequence, tokenizer, initial_decoder_input_length, batch_size
-                )
-                for sequence in stop_sequences
-            ],
-        ]
-    )
\ No newline at end of file
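As a closing note (not part of the commit), a hedged sketch of the argument format the reworked `generate_until` loop now expects from each request: a `(context, gen_kwargs)` tuple rather than a `(context, until)` pair. The concrete values below are illustrative only.

```python
# Illustrative request args under the new format: the second element is a dict
# of generation kwargs rather than a bare stop string or list of stop strings.
request_args = (
    "Question: What is the capital of France?\nAnswer:",
    {"until": ["\n\n", "Question:"], "max_gen_toks": 64, "do_sample": False},
)

context, gen_kwargs = request_args
gen_kwargs = dict(gen_kwargs)                        # avoid mutating shared state
until = gen_kwargs.pop("until", [])                  # list of stop strings
max_gen_toks = gen_kwargs.pop("max_gen_toks", 256)   # fall back to a model default
# whatever remains (here just do_sample) is forwarded to model.generate(**gen_kwargs)
print(context, until, max_gen_toks, gen_kwargs)
```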