Commit 9cf4a104 authored by haileyschoelkopf

more pre-commit

parent 306cfada
@@ -905,7 +905,9 @@ class ConfigurableTask(Task):
             for key, result in zip(self._metric_fn_list.keys(), results):
                 _dict = self._metric_fn_list[key].compute(
-                    references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
+                    references=[gold],
+                    predictions=[result],
+                    **self._metric_fn_kwargs[key],
                 )

                 result_dict = {**result_dict, **_dict}
...
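The hunk above only re-wraps the call to each configured metric's compute(); the underlying pattern is to merge every metric's output dict into a single result dict. A minimal sketch of that pattern, with metric_fn_list and metric_fn_kwargs as illustrative stand-ins for the task's configured HF evaluate-style metric objects:

# Sketch: merge per-metric compute() outputs into one result dict, as the loop above does.
# `metric_fn_list` maps metric name -> an object exposing .compute(references=..., predictions=..., **kwargs)
# (e.g. modules returned by evaluate.load()); the kwargs dicts are illustrative.
def score_one_doc(metric_fn_list, metric_fn_kwargs, gold, results):
    result_dict = {}
    for key, result in zip(metric_fn_list.keys(), results):
        _dict = metric_fn_list[key].compute(
            references=[gold],
            predictions=[result],
            **metric_fn_kwargs[key],
        )
        # each compute() returns a dict of scores; later metrics overwrite clashing keys
        result_dict = {**result_dict, **_dict}
    return result_dict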
@@ -183,9 +183,7 @@ def evaluate(
     # get lists of each type of request
     for task_name, task in task_dict.items():
         versions[task_name] = task.VERSION
-        configs[task_name] = dict(
-            task.dump_config()
-        )
+        configs[task_name] = dict(task.dump_config())

         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
         # task_docs = list(task_doc_func())
...
@@ -27,6 +27,7 @@ class HFLM(LM):
     """

     AUTO_MODEL_CLASS = None
+
     def __init__(
         self,
         device="cuda",
@@ -79,7 +80,10 @@ class HFLM(LM):
         else:
             self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

-        assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
+        assert self.AUTO_MODEL_CLASS in [
+            transformers.AutoModelForCausalLM,
+            transformers.AutoModelForSeq2SeqLM,
+        ]

         self._model = self.AUTO_MODEL_CLASS.from_pretrained(
             pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
@@ -175,9 +179,7 @@ class HFLM(LM):
         return self._world_size

     def tok_encode(self, string: str, left_truncate_len=None):
-        """
-
-        """
+        """ """
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             add_special_tokens = False
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
@@ -197,23 +199,9 @@ class HFLM(LM):
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask = None ,labels = None):
-        """
-        inps: a torch tensor of shape [batch, sequence_ctx]
-        the size of sequence may vary from call to call
-        labels: a torch tensor of shape [batch, sequence_cont]
-        the size of sequence may vary from call to call
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        with torch.no_grad():
-            return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
-
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
-        inps: torch.Tensor
+        :param inps: torch.Tensor
             A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
             [batch, sequence_ctx]. the size of sequence may vary from call to call
         :param attn_mask: torch.Tensor, optional
@@ -229,7 +217,9 @@ class HFLM(LM):
         with torch.no_grad():
             if attn_mask or labels:
                 assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
+                return self.model(
+                    input_ids=inps, attention_mask=attn_mask, labels=labels
+                ).logits
             else:
                 assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                 return self.model(inps).logits
@@ -254,15 +244,19 @@ class HFLM(LM):
     def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            assert (contlen and inplen), "Must pass input len and cont. len to select scored logits for causal LM"
+            assert (
+                contlen and inplen
+            ), "Must pass input len and cont. len to select scored logits for causal LM"
             # discard right-padding.
             # also discard the input/context tokens. we'll only score continuations.
             logits = logits[inplen - contlen : inplen]
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            assert (contlen and not inplen), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            assert (
+                contlen and not inplen
+            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
             # only discard right-padding.
             # the logits input to this fn only contain decoder-side tokens.
-            logits = logits[: contlen]
+            logits = logits[:contlen]

         return logits
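For reference, the slicing that _select_cont_toks performs (reformatted above) can be illustrated on dummy tensors: a causal LM is fed context plus continuation, so the continuation scores sit in the window ending at inplen, while a seq2seq decoder only sees the continuation, so dropping right-padding is enough. A self-contained sketch with made-up shapes and values:

import torch

# Causal LM: the model saw context + continuation (possibly right-padded),
# so the continuation's scores are the `contlen` rows just before position `inplen`.
inplen, contlen, vocab = 7, 3, 11
causal_logits = torch.randn(10, vocab)                    # [padded_seq, vocab]
cont_scores = causal_logits[inplen - contlen : inplen]    # [contlen, vocab]

# Seq2seq LM: logits only cover decoder-side (continuation) tokens,
# so we only need to drop right-padding.
dec_logits = torch.randn(6, vocab)
cont_scores_s2s = dec_logits[:contlen]                    # [contlen, vocab]

assert cont_scores.shape == cont_scores_s2s.shape == (contlen, vocab)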
@@ -296,7 +290,7 @@ class HFLM(LM):
                 )
             )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0
@@ -386,7 +380,7 @@ class HFLM(LM):
             inp = torch.tensor(
                 (context_enc)[-self.max_length :],
                 dtype=torch.long,
-                device=self.device
+                device=self.device,
             )
             (inplen,) = inp.shape

             cont = torch.tensor(
@@ -400,9 +394,17 @@ class HFLM(LM):
                 conts.append(cont)

-                padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
+                padding_len_cont = (
+                    max(padding_len_cont, contlen)
+                    if padding_len_cont is not None
+                    else contlen
+                )

-            padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
+            padding_len_inp = (
+                max(padding_len_inp, inplen)
+                if padding_len_inp is not None
+                else inplen
+            )

             inps.append(inp)  # [1, inp_length]
             cont_toks_list.append(continuation_enc)
@@ -411,13 +413,24 @@ class HFLM(LM):
             # create encoder attn mask and batched conts, if seq2seq
             call_kwargs = {}
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right") # [batch, padding_len_inp]
+                batched_inps = utils.pad_and_concat(
+                    padding_len_inp, inps, padding_side="right"
+                )  # [batch, padding_len_inp]
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 # TODO: left-pad encoder inps and mask?
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp]
-                batched_conts = utils.pad_and_concat(padding_len_cont, conts) # [batch, padding_len_cont]
-                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns) # [batch, padding_len_inp]
-                call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
+                batched_inps = utils.pad_and_concat(
+                    padding_len_inp, inps
+                )  # [batch, padding_len_inp]
+                batched_conts = utils.pad_and_concat(
+                    padding_len_cont, conts
+                )  # [batch, padding_len_cont]
+                batched_encoder_mask = utils.pad_and_concat(
+                    padding_len_inp, encoder_attns
+                )  # [batch, padding_len_inp]
+                call_kwargs = {
+                    "attn_mask": batched_encoder_mask,
+                    "labels": batched_conts,
+                }

             multi_logits = F.log_softmax(
                 self._model_call(batched_inps, **call_kwargs), dim=-1
@@ -431,11 +444,13 @@ class HFLM(LM):
                 contlen = len(cont_toks)
                 # take only logits in the continuation
                 # (discard context toks if decoder-only ; discard right-padding)
-                ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+                ctx_len = (
+                    inplen
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    else None
+                )
                 logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
-                logits = logits.unsqueeze(
-                    0
-                )  # [1, seq, vocab]
+                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
...
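The last hunk stops at the per-token argmax. Downstream, continuation log-likelihoods are typically reduced by checking whether greedy decoding reproduces the target and summing the gathered per-token log-probabilities; a hedged sketch of that pattern with dummy values (not verbatim repo code):

import torch
import torch.nn.functional as F

# Illustrative continuation scoring, following the pattern in the hunk above.
# `logits` holds log-softmaxed scores for the continuation slice, [1, contlen, vocab];
# `cont_toks` is the reference continuation token ids. Both are dummy values here.
vocab = 11
cont_toks = torch.tensor([[2, 5, 7]])                      # [1, contlen]
logits = F.log_softmax(torch.randn(1, 3, vocab), dim=-1)   # [1, contlen, vocab]

greedy_tokens = logits.argmax(dim=-1)                       # [1, contlen]
max_equal = (greedy_tokens == cont_toks).all()              # does greedy decoding reproduce the target?
token_logprobs = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, contlen]
answer = (float(token_logprobs.sum()), bool(max_equal))     # (loglikelihood, is_greedy)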
@@ -19,6 +19,7 @@ from accelerate import Accelerator
 @register_model("hf-seq2seq", "seq2seq")
 class Seq2SeqHFLM(LM):
     _DEFAULT_MAX_LENGTH: int = 2048
+
     def __init__(
         self,
         device="cuda",
@@ -111,7 +112,8 @@ class Seq2SeqHFLM(LM):
     @property
     def max_length(self):
-        return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
+        return self._DEFAULT_MAX_LENGTH  # TODO: Is this a good default?
+
     @property
     def max_gen_toks(self):
         return 256
@@ -138,7 +140,7 @@ class Seq2SeqHFLM(LM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask = None ,labels = None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
@@ -150,7 +152,9 @@ class Seq2SeqHFLM(LM):
             logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
+            return self.model(
+                input_ids=inps, attention_mask=attn_mask, labels=labels
+            ).logits

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
@@ -208,7 +212,7 @@ class Seq2SeqHFLM(LM):
                 )
             )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0
@@ -283,8 +287,16 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape

-                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
-                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
+                max_batch_length_inp = (
+                    max(max_batch_length_inp, inplen)
+                    if max_batch_length_inp is not None
+                    else inplen
+                )
+                max_batch_length_cont = (
+                    max(max_batch_length_cont, contlen)
+                    if max_batch_length_cont is not None
+                    else contlen
+                )

                 inps.append(inp)  # [1, inp_len]
                 conts.append(cont)  # [1, cont_len]
@@ -292,13 +304,22 @@ class Seq2SeqHFLM(LM):
                 cont_toks_list.append(continuation_enc)

-            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps) # [batch, padding_length]
-            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts) # [batch, padding_length]
-            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
+            batched_inps = utils.pad_and_concat(
+                max_batch_length_inp, inps
+            )  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(
+                max_batch_length_cont, conts
+            )  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(
+                max_batch_length_inp, encoder_attns
+            )
             # need to make attention mask here too
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, attn_mask = batched_encoder_mask, labels = batched_conts), dim=-1
+                self._model_call(
+                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
+                ),
+                dim=-1,
             ).cpu()  # [batch, padding_length, vocab]

             for (cache_key, _, _), logits, cont_toks in zip(
@@ -307,9 +328,7 @@ class Seq2SeqHFLM(LM):
                 # Slice to original seq length
                 contlen = len(cont_toks)
-                logits = logits[: contlen].unsqueeze(
-                    0
-                )  # [1, seq, vocab]
+                logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
@@ -383,4 +402,3 @@ class Seq2SeqHFLM(LM):
             res.append(s)

         return re_ord.get_original(res)
-
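In the seq2seq path above, _model_call forwards the padded continuations as labels so the decoder emits one logit row per continuation token. A minimal, hedged sketch of that call pattern using the Hugging Face API ("t5-small" is only an example checkpoint, not something the harness prescribes):

import torch
import transformers

# Sketch of the seq2seq call pattern used by _model_call above.
tok = transformers.AutoTokenizer.from_pretrained("t5-small")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")

enc = tok("translate English to German: hello", return_tensors="pt")
labels = tok("hallo", return_tensors="pt").input_ids   # decoder-side targets

with torch.no_grad():
    logits = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        labels=labels,
    ).logits                                           # [batch, label_len, vocab]
print(logits.shape)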
@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
     Calling this function
     """
     for root, subdirs, file_list in os.walk(task_dir):
-        if (len(file_list) > 0):
+        if len(file_list) > 0:
            for f in file_list:
                if f.endswith(".yaml"):
                    yaml_path = os.path.join(root, f)
...
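The condition simplified above sits inside a plain os.walk scan for task YAMLs; a self-contained sketch of that discovery pattern (function name and example path are illustrative, not the repo's API):

import os

def find_yaml_configs(task_dir):
    """Collect *.yaml paths under task_dir, mirroring the walk in include_task_folder."""
    yaml_paths = []
    for root, subdirs, file_list in os.walk(task_dir):
        if len(file_list) > 0:
            for f in file_list:
                if f.endswith(".yaml"):
                    yaml_paths.append(os.path.join(root, f))
    return yaml_paths

# e.g. find_yaml_configs("lm_eval/tasks") would list every task YAML under that folder.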
@@ -18,11 +18,12 @@ import torch
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
-import torch

 import transformers

 from lm_eval.logger import eval_logger


 class ExitCodeError(Exception):
     pass
@@ -416,13 +417,16 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     """
     return islice(raw_iterator, rank, limit, world_size)

-def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="right"):
+
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
     """
-    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

     for i, tensor in enumerate(tensors):
         tensor_len = tensor.shape[0]
@@ -456,7 +460,7 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="right"):
         else:
             tensors[i] = tensor.unsqueeze(0)

-    return torch.cat(tensors, dim = 0)
+    return torch.cat(tensors, dim=0)


 def clear_torch_cache():
...
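pad_and_concat, reformatted above, pads each 1-D tensor in a batch to a common length on the chosen side and stacks the results into a [batch, max_length] tensor. A simplified stand-in illustrating the intended semantics (zero-padding is an assumption here; the real helper's details may differ):

import torch
from typing import List

def pad_and_concat_sketch(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
    """Simplified stand-in for utils.pad_and_concat: zero-pad each 1-D tensor to
    max_length on the given side, then stack into a [batch, max_length] tensor."""
    assert padding_side in ("left", "right")
    out = []
    for tensor in tensors:
        pad = torch.zeros(max_length - tensor.shape[0], dtype=tensor.dtype)
        row = torch.cat([tensor, pad]) if padding_side == "right" else torch.cat([pad, tensor])
        out.append(row.unsqueeze(0))
    return torch.cat(out, dim=0)

batch = pad_and_concat_sketch(5, [torch.tensor([1, 2, 3]), torch.tensor([4])])
print(batch.shape)  # torch.Size([2, 5])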