Commit b48f5205 authored by haileyschoelkopf, committed by lintangsutawika

more pre-commit

parent 86b71954
......@@ -102,7 +102,7 @@ class TaskConfig(dict):
assert (
self.output_type == "greedy_until"
), "passed `generation_kwargs`, but not using a generation request type!"
elif self.output_type == "greedy_until":
elif self.output_type == "greedy_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
......@@ -883,7 +883,9 @@ class ConfigurableTask(Task):
for key, result in zip(self._metric_fn_list.keys(), results):
_dict = self._metric_fn_list[key].compute(
references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
references=[gold],
predictions=[result],
**self._metric_fn_kwargs[key],
)
result_dict = {**result_dict, **_dict}
......
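# Illustrative sketch (not part of this commit): the compute() call above appears to follow
# the Hugging Face `evaluate` metric interface. A minimal, self-contained example of that
# interface, assuming the `evaluate` package and its "exact_match" metric:
import evaluate

metric_fn_list = {"exact_match": evaluate.load("exact_match")}
metric_fn_kwargs = {"exact_match": {}}

gold, result = "Paris", "Paris"
result_dict = {}
for key, metric in metric_fn_list.items():
    _dict = metric.compute(references=[gold], predictions=[result], **metric_fn_kwargs[key])
    result_dict = {**result_dict, **_dict}
# result_dict -> {"exact_match": 1.0}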
......@@ -183,10 +183,8 @@ def evaluate(
# get lists of each type of request
for task_name, task in task_dict.items():
versions[task_name] = task.VERSION
configs[task_name] = dict(
task.dump_config()
)
configs[task_name] = dict(task.dump_config())
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
# task_docs = list(task_doc_func())
# rnd = random.Random()
......
......@@ -27,6 +27,7 @@ class HFLM(LM):
"""
AUTO_MODEL_CLASS = None
def __init__(
self,
device="cuda",
......@@ -44,7 +45,7 @@ class HFLM(LM):
assert isinstance(batch_size, int)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device not in ["cuda", "cpu"]:
......@@ -68,7 +69,7 @@ class HFLM(LM):
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
# get config
# get config
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
......@@ -77,9 +78,12 @@ class HFLM(LM):
if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
else:
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM,
transformers.AutoModelForSeq2SeqLM,
]
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
......@@ -127,7 +131,7 @@ class HFLM(LM):
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
......@@ -175,20 +179,18 @@ class HFLM(LM):
return self._world_size
def tok_encode(self, string: str, left_truncate_len=None):
"""
"""
""" """
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = False
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
add_special_tokens = True
encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if left_truncate_len:
encoding = encoding[-left_truncate_len:]
return encoding
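# Illustrative sketch (not part of the diff): how add_special_tokens and left_truncate_len
# interact, assuming a GPT-2 tokenizer as a stand-in for a causal-LM tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
encoding = tokenizer.encode(
    "a long prompt that may exceed the context window", add_special_tokens=False
)
left_truncate_len = 4
encoding = encoding[-left_truncate_len:]  # keep only the rightmost 4 token ids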
def tok_decode(self, tokens):
......@@ -197,23 +199,9 @@ class HFLM(LM):
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
return self.tokenizer.decode(tokens, skip_special_tokens=True)
def _model_call(self, inps, attn_mask = None ,labels = None):
"""
inps: a torch tensor of shape [batch, sequence_ctx]
the size of sequence may vary from call to call
labels: a torch tensor of shape [batch, sequence_cont]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
with torch.no_grad():
return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
def _model_call(self, inps, attn_mask=None, labels=None):
"""
inps: torch.Tensor
:param inps: torch.Tensor
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
[batch, sequence_ctx]. the size of sequence may vary from call to call
:param attn_mask: torch.Tensor, optional
......@@ -229,7 +217,9 @@ class HFLM(LM):
with torch.no_grad():
if attn_mask or labels:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
return self.model(
input_ids=inps, attention_mask=attn_mask, labels=labels
).logits
else:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
return self.model(inps).logits
......@@ -254,16 +244,20 @@ class HFLM(LM):
def _select_cont_toks(self, logits, contlen=None, inplen=None):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
assert (contlen and inplen), "Must pass input len and cont. len to select scored logits for causal LM"
assert (
contlen and inplen
), "Must pass input len and cont. len to select scored logits for causal LM"
# discard right-padding.
# also discard the input/context tokens. we'll only score continuations.
logits = logits[inplen - contlen : inplen]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
assert (contlen and not inplen), "Selecting scored logits for Seq2SeqLM requires only cont. len"
# only discard right-padding.
assert (
contlen and not inplen
), "Selecting scored logits for Seq2SeqLM requires only cont. len"
# only discard right-padding.
# the logits input to this fn only contain decoder-side tokens.
logits = logits[: contlen]
logits = logits[:contlen]
return logits
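# Illustrative sketch (not part of the diff) of the causal-LM branch above: `inplen` counts
# the non-padded positions and `contlen` the continuation tokens, so the slice keeps exactly
# the positions that score the continuation and drops context tokens plus right-padding.
import torch

vocab_size = 10
inplen, contlen = 7, 3                 # 7 real positions, last 3 score the continuation
logits = torch.randn(9, vocab_size)    # 9 = inplen + 2 positions of right-padding
cont_logits = logits[inplen - contlen : inplen]
assert cont_logits.shape == (contlen, vocab_size)  # [3, 10]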
def loglikelihood(self, requests):
......@@ -289,14 +283,14 @@ class HFLM(LM):
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
#TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
pad_amnt = 0
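# Illustrative sketch (not part of the diff): prepending None gives each rolling window the
# same (cache_key, context_enc, continuation_enc) shape as the request tuples built in
# loglikelihood(), so _loglikelihood_tokens can process both uniformly. Token ids below are
# hypothetical.
rolling_token_windows = [([50256], [464, 3290]), ([464, 3290], [318, 922])]
requests_like = [(None,) + x for x in rolling_token_windows]
# -> [(None, [50256], [464, 3290]), (None, [464, 3290], [318, 922])]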
......@@ -386,11 +380,11 @@ class HFLM(LM):
inp = torch.tensor(
(context_enc)[-self.max_length :],
dtype=torch.long,
device=self.device
device=self.device,
)
(inplen,) = inp.shape
cont = torch.tensor(
(continuation_enc)[-self.max_length :],
(continuation_enc)[-self.max_length :],
# TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long,
......@@ -400,24 +394,43 @@ class HFLM(LM):
conts.append(cont)
padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
padding_len_cont = (
max(padding_len_cont, contlen)
if padding_len_cont is not None
else contlen
)
padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
padding_len_inp = (
max(padding_len_inp, inplen)
if padding_len_inp is not None
else inplen
)
inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)
# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right") # [batch, padding_len_inp]
batched_inps = utils.pad_and_concat(
padding_len_inp, inps, padding_side="right"
) # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: left-pad encoder inps and mask?
batched_inps = utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp]
batched_conts = utils.pad_and_concat(padding_len_cont, conts) # [batch, padding_len_cont]
batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns) # [batch, padding_len_inp]
call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
batched_inps = utils.pad_and_concat(
padding_len_inp, inps
) # [batch, padding_len_inp]
batched_conts = utils.pad_and_concat(
padding_len_cont, conts
) # [batch, padding_len_cont]
batched_encoder_mask = utils.pad_and_concat(
padding_len_inp, encoder_attns
) # [batch, padding_len_inp]
call_kwargs = {
"attn_mask": batched_encoder_mask,
"labels": batched_conts,
}
multi_logits = F.log_softmax(
self._model_call(batched_inps, **call_kwargs), dim=-1
......@@ -429,13 +442,15 @@ class HFLM(LM):
# Slice to original seq length
contlen = len(cont_toks)
# take only logits in the continuation
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
ctx_len = (
inplen
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
else None
)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
logits = logits.unsqueeze(
0
) # [1, seq, vocab]
logits = logits.unsqueeze(0) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
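# Illustrative sketch (not part of the diff): one way the per-request answer can be derived
# from the sliced, log-softmaxed logits; names and bookkeeping here are simplified and not
# necessarily identical to the harness's own code.
import torch
import torch.nn.functional as F

cont_toks = torch.tensor([[42, 7, 13]])                   # [1, seq] gold continuation ids
logits = F.log_softmax(torch.randn(1, 3, 100), dim=-1)    # [1, seq, vocab]

greedy_tokens = logits.argmax(dim=-1)                     # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()

# log-prob assigned to each gold continuation token, summed over the sequence
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
answer = (float(token_logprobs.sum()), bool(max_equal))   # (loglikelihood, is_greedy)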
......@@ -506,8 +521,8 @@ class HFLM(LM):
).to(self.device)
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
......@@ -519,4 +534,4 @@ class HFLM(LM):
res.append(s)
return re_ord.get_original(res)
\ No newline at end of file
return re_ord.get_original(res)
......@@ -19,6 +19,7 @@ from accelerate import Accelerator
@register_model("hf-seq2seq", "seq2seq")
class Seq2SeqHFLM(LM):
_DEFAULT_MAX_LENGTH: int = 2048
def __init__(
self,
device="cuda",
......@@ -111,7 +112,8 @@ class Seq2SeqHFLM(LM):
@property
def max_length(self):
return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
return self._DEFAULT_MAX_LENGTH # TODO: Is this a good default?
@property
def max_gen_toks(self):
return 256
......@@ -131,14 +133,14 @@ class Seq2SeqHFLM(LM):
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=True)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens, skip_special_tokens=True)
def _model_call(self, inps, attn_mask = None ,labels = None):
def _model_call(self, inps, attn_mask=None, labels=None):
"""
inps: a torch tensor of shape [batch, sequence_ctx]
the size of sequence may vary from call to call
......@@ -150,8 +152,10 @@ class Seq2SeqHFLM(LM):
logits returned from the model
"""
with torch.no_grad():
return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
return self.model(
input_ids=inps, attention_mask=attn_mask, labels=labels
).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
......@@ -176,8 +180,8 @@ class Seq2SeqHFLM(LM):
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
**generation_kwargs,
)
)
def loglikelihood(self, requests):
new_reqs = []
for context, continuation in [req.args for req in requests]:
......@@ -192,7 +196,7 @@ class Seq2SeqHFLM(LM):
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
......@@ -201,14 +205,14 @@ class Seq2SeqHFLM(LM):
utils.make_disjoint_window,
utils.get_rolling_token_windows(
token_list=self.tok_encode(string),
prefix_token=self.eot_token_id,
prefix_token=self.eot_token_id,
max_seq_len=self.max_length,
context_len=1,
),
)
)
#TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
pad_amnt = 0
......@@ -237,7 +241,7 @@ class Seq2SeqHFLM(LM):
loglikelihoods.append(string_nll)
return loglikelihoods
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
res = []
......@@ -251,7 +255,7 @@ class Seq2SeqHFLM(LM):
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
......@@ -261,7 +265,7 @@ class Seq2SeqHFLM(LM):
conts = []
encoder_attns = []
cont_toks_list = []
max_batch_length_inp = None
max_batch_length_cont = None
......@@ -283,33 +287,48 @@ class Seq2SeqHFLM(LM):
).to(self.device)
(contlen,) = cont.shape
max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
max_batch_length_inp = (
max(max_batch_length_inp, inplen)
if max_batch_length_inp is not None
else inplen
)
max_batch_length_cont = (
max(max_batch_length_cont, contlen)
if max_batch_length_cont is not None
else contlen
)
inps.append(inp) # [1, inp_len]
conts.append(cont) # [1, cont_len]
conts.append(cont) # [1, cont_len]
encoder_attns.append(torch.ones_like(inp))
cont_toks_list.append(continuation_enc)
batched_inps = utils.pad_and_concat(max_batch_length_inp, inps) # [batch, padding_length]
batched_conts = utils.pad_and_concat(max_batch_length_cont, conts) # [batch, padding_length]
batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
batched_inps = utils.pad_and_concat(
max_batch_length_inp, inps
) # [batch, padding_length]
batched_conts = utils.pad_and_concat(
max_batch_length_cont, conts
) # [batch, padding_length]
batched_encoder_mask = utils.pad_and_concat(
max_batch_length_inp, encoder_attns
)
# need to make attention mask here too
multi_logits = F.log_softmax(
self._model_call(batched_inps, attn_mask = batched_encoder_mask, labels = batched_conts), dim=-1
self._model_call(
batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
),
dim=-1,
).cpu() # [batch, padding_length, vocab]
for (cache_key, _, _), logits, cont_toks in zip(
chunk, multi_logits, cont_toks_list
):
# Slice to original seq length
# Slice to original seq length
contlen = len(cont_toks)
logits = logits[: contlen].unsqueeze(
0
) # [1, seq, vocab]
logits = logits[:contlen].unsqueeze(0) # [1, seq, vocab]
# Check if per-token argmax is exactly equal to continuation
greedy_tokens = logits.argmax(dim=-1)
......@@ -329,7 +348,7 @@ class Seq2SeqHFLM(LM):
res.append(answer)
return re_ord.get_original(res)
def greedy_until(self, requests):
res = []
......@@ -370,8 +389,8 @@ class Seq2SeqHFLM(LM):
).to(self.device)
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
stop=primary_until,
**gen_kwargs,
)
......@@ -383,4 +402,3 @@ class Seq2SeqHFLM(LM):
res.append(s)
return re_ord.get_original(res)
......@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
Calling this function
"""
for root, subdirs, file_list in os.walk(task_dir):
if (len(file_list) > 0):
if len(file_list) > 0:
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
......
......@@ -19,6 +19,11 @@ import transformers
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice
from lm_eval.logger import eval_logger
......@@ -417,6 +422,7 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
return islice(raw_iterator, rank, limit, world_size)
def clear_torch_cache():
gc.collect()
torch.cuda.empty_cache()
......@@ -437,8 +443,17 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
seq2seq models.
"""
assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
assert (
padding_side == "left" or padding_side == "right"
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
tensor_len = tensor.shape[0]
......@@ -446,36 +461,45 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
if padding_side == "right":
# right-pad
tensors[i] = torch.cat(
[
tensor, # [seq]
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
],
dim=0,
).unsqueeze(0)
[
tensor, # [seq]
torch.zeros(
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
],
dim=0,
).unsqueeze(0)
else:
# left-pad
tensors[i] = torch.cat(
[
torch.zeros(
max_length - tensor_len,
max_length - tensor_len,
dtype=torch.long,
device=tensor.device,
), # [padding_length - seq]
tensor, # [seq]
tensor, # [seq]
],
dim=0,
).unsqueeze(0)
else:
tensors[i] = tensor.unsqueeze(0)
return torch.cat(tensors, dim = 0)
return torch.cat(tensors, dim=0)
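# Illustrative usage (not part of the diff) of pad_and_concat as defined above: two
# variable-length 1-D token tensors are padded to a shared length and stacked. Note that
# the helper mutates the list it is given.
import torch

a = torch.tensor([1, 2, 3], dtype=torch.long)
b = torch.tensor([4, 5], dtype=torch.long)

batch = pad_and_concat(3, [a, b], padding_side="left")
# tensor([[1, 2, 3],
#         [0, 4, 5]])   shape: [2, 3]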
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""
......
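# Illustrative sketch (not part of the diff): the class body is elided above, so the
# following is only a guess at how a multi-token stop sequence can be expressed with the
# transformers.StoppingCriteria interface; it is not the harness's actual implementation,
# and StopOnSequence is a hypothetical name.
import torch
import transformers


class StopOnSequence(transformers.StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = list(stop_ids)  # token ids of the stop sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # stop once the most recently generated tokens of the first sequence match the stop ids
        tail = input_ids[0, -len(self.stop_ids):].tolist()
        return tail == self.stop_ids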