Commit b48f5205 authored by haileyschoelkopf, committed by lintangsutawika

more pre-commit

parent 86b71954
@@ -102,7 +102,7 @@ class TaskConfig(dict):
            assert (
                self.output_type == "greedy_until"
            ), "passed `generation_kwargs`, but not using a generation request type!"
        elif self.output_type == "greedy_until":
            # ensure that we greedily generate in absence of explicit arguments otherwise
            self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
@@ -883,7 +883,9 @@ class ConfigurableTask(Task):
        for key, result in zip(self._metric_fn_list.keys(), results):
            _dict = self._metric_fn_list[key].compute(
-                references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
+                references=[gold],
+                predictions=[result],
+                **self._metric_fn_kwargs[key],
            )
            result_dict = {**result_dict, **_dict}
...
@@ -183,10 +183,8 @@ def evaluate(
    # get lists of each type of request
    for task_name, task in task_dict.items():
        versions[task_name] = task.VERSION
-        configs[task_name] = dict(
-            task.dump_config()
-        )
+        configs[task_name] = dict(task.dump_config())

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        # task_docs = list(task_doc_func())
        # rnd = random.Random()
...
@@ -27,6 +27,7 @@ class HFLM(LM):
    """

    AUTO_MODEL_CLASS = None

    def __init__(
        self,
        device="cuda",

@@ -44,7 +45,7 @@ class HFLM(LM):
        assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()
        if gpus <= 1:
            if device:
                if device not in ["cuda", "cpu"]:
@@ -68,7 +69,7 @@ class HFLM(LM):
            # TODO: update this to be less of a hack once subfolder is fixed in HF
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            # get config
            self._config = transformers.AutoConfig.from_pretrained(
                pretrained,
                revision=revision,

@@ -77,9 +78,12 @@ class HFLM(LM):
            if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
            else:
                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
-            assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
+            assert self.AUTO_MODEL_CLASS in [
+                transformers.AutoModelForCausalLM,
+                transformers.AutoModelForSeq2SeqLM,
+            ]

            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
@@ -127,7 +131,7 @@ class HFLM(LM):
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.

@@ -175,20 +179,18 @@ class HFLM(LM):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None):
-        """
-        """
+        """ """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_decode(self, tokens):
@@ -197,23 +199,9 @@ class HFLM(LM):
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask = None ,labels = None):
-        """
-        inps: a torch tensor of shape [batch, sequence_ctx]
-        the size of sequence may vary from call to call
-        labels: a torch tensor of shape [batch, sequence_cont]
-        the size of sequence may vary from call to call
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        with torch.no_grad():
-            return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
-        inps: torch.Tensor
+        :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional

@@ -229,7 +217,9 @@ class HFLM(LM):
        with torch.no_grad():
            if attn_mask or labels:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
+                return self.model(
+                    input_ids=inps, attention_mask=attn_mask, labels=labels
+                ).logits
            else:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits
@@ -254,16 +244,20 @@ class HFLM(LM):
    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            assert (contlen and inplen), "Must pass input len and cont. len to select scored logits for causal LM"
+            assert (
+                contlen and inplen
+            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            assert (contlen and not inplen), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            assert (
+                contlen and not inplen
+            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
-            logits = logits[: contlen]
+            logits = logits[:contlen]

        return logits

    def loglikelihood(self, requests):
@@ -289,14 +283,14 @@ class HFLM(LM):
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0
@@ -386,11 +380,11 @@ class HFLM(LM):
                inp = torch.tensor(
                    (context_enc)[-self.max_length :],
                    dtype=torch.long,
-                    device=self.device
+                    device=self.device,
                )
                (inplen,) = inp.shape

                cont = torch.tensor(
                    (continuation_enc)[-self.max_length :],
                    # TODO: left-shift these?
                    # TODO: our code assumes we never end up truncating conts for either model type
                    dtype=torch.long,
@@ -400,24 +394,43 @@ class HFLM(LM):
                conts.append(cont)

-                padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
+                padding_len_cont = (
+                    max(padding_len_cont, contlen)
+                    if padding_len_cont is not None
+                    else contlen
+                )

-                padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
+                padding_len_inp = (
+                    max(padding_len_inp, inplen)
+                    if padding_len_inp is not None
+                    else inplen
+                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")  # [batch, padding_len_inp]
+                batched_inps = utils.pad_and_concat(
+                    padding_len_inp, inps, padding_side="right"
+                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
-                batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
-                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # [batch, padding_len_inp]
-                call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
+                batched_inps = utils.pad_and_concat(
+                    padding_len_inp, inps
+                )  # [batch, padding_len_inp]
+                batched_conts = utils.pad_and_concat(
+                    padding_len_cont, conts
+                )  # [batch, padding_len_cont]
+                batched_encoder_mask = utils.pad_and_concat(
+                    padding_len_inp, encoder_attns
+                )  # [batch, padding_len_inp]
+                call_kwargs = {
+                    "attn_mask": batched_encoder_mask,
+                    "labels": batched_conts,
+                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
@@ -429,13 +442,15 @@ class HFLM(LM):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
-                ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+                ctx_len = (
+                    inplen
+                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    else None
+                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
-                logits = logits.unsqueeze(
-                    0
-                )  # [1, seq, vocab]
+                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
@@ -506,8 +521,8 @@ class HFLM(LM):
            ).to(self.device)

            cont = self._model_generate(
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                stop=primary_until,
                **gen_kwargs,
            )

@@ -519,4 +534,4 @@ class HFLM(LM):
            res.append(s)

        return re_ord.get_original(res)
\ No newline at end of file
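The `_select_cont_toks` hunk above only re-wraps the assertions; the slicing itself is unchanged. A minimal sketch of that indexing, with made-up shapes and values (not part of the commit):

```python
import torch

# Toy illustration of the slicing done in HFLM._select_cont_toks above.
vocab_size = 5
logits = torch.randn(10, vocab_size)  # [padded_seq, vocab] for one sequence
inplen, contlen = 8, 3  # unpadded input length, continuation length

# Causal LM: drop right-padding and the context positions, keeping only the
# positions that score the continuation tokens.
causal_scores = logits[inplen - contlen : inplen]  # shape [3, vocab]

# Seq2seq LM: decoder-side logits already cover only the continuation
# (plus padding), so only right-padding is discarded.
seq2seq_scores = logits[:contlen]  # shape [3, vocab]

print(causal_scores.shape, seq2seq_scores.shape)
```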
@@ -19,6 +19,7 @@ from accelerate import Accelerator
@register_model("hf-seq2seq", "seq2seq")
class Seq2SeqHFLM(LM):
    _DEFAULT_MAX_LENGTH: int = 2048

    def __init__(
        self,
        device="cuda",

@@ -111,7 +112,8 @@ class Seq2SeqHFLM(LM):
    @property
    def max_length(self):
-        return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
+        return self._DEFAULT_MAX_LENGTH  # TODO: Is this a good default?

    @property
    def max_gen_toks(self):
        return 256
@@ -131,14 +133,14 @@ class Seq2SeqHFLM(LM):
    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=True)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask = None ,labels = None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        inps: a torch tensor of shape [batch, sequence_ctx]
        the size of sequence may vary from call to call

@@ -150,8 +152,10 @@ class Seq2SeqHFLM(LM):
        logits returned from the model
        """
        with torch.no_grad():
-            return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
+            return self.model(
+                input_ids=inps, attention_mask=attn_mask, labels=labels
+            ).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.
@@ -176,8 +180,8 @@ class Seq2SeqHFLM(LM):
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
            **generation_kwargs,
        )

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:

@@ -192,7 +196,7 @@ class Seq2SeqHFLM(LM):
            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
@@ -201,14 +205,14 @@ class Seq2SeqHFLM(LM):
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0

@@ -237,7 +241,7 @@ class Seq2SeqHFLM(LM):
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

@@ -251,7 +255,7 @@ class Seq2SeqHFLM(LM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)

        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),

@@ -261,7 +265,7 @@ class Seq2SeqHFLM(LM):
            conts = []
            encoder_attns = []
            cont_toks_list = []

            max_batch_length_inp = None
            max_batch_length_cont = None
@@ -283,33 +287,48 @@ class Seq2SeqHFLM(LM):
                ).to(self.device)
                (contlen,) = cont.shape

-                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
-                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
+                max_batch_length_inp = (
+                    max(max_batch_length_inp, inplen)
+                    if max_batch_length_inp is not None
+                    else inplen
+                )
+                max_batch_length_cont = (
+                    max(max_batch_length_cont, contlen)
+                    if max_batch_length_cont is not None
+                    else contlen
+                )

                inps.append(inp)  # [1, inp_len]
                conts.append(cont)  # [1, cont_len]
                encoder_attns.append(torch.ones_like(inp))
                cont_toks_list.append(continuation_enc)

-            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
-            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
-            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
+            batched_inps = utils.pad_and_concat(
+                max_batch_length_inp, inps
+            )  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(
+                max_batch_length_cont, conts
+            )  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(
+                max_batch_length_inp, encoder_attns
+            )
            # need to make attention mask here too
            multi_logits = F.log_softmax(
-                self._model_call(batched_inps, attn_mask = batched_encoder_mask, labels = batched_conts), dim=-1
+                self._model_call(
+                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
+                ),
+                dim=-1,
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, cont_toks in zip(
                chunk, multi_logits, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
-                logits = logits[: contlen].unsqueeze(
-                    0
-                )  # [1, seq, vocab]
+                logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)
@@ -329,7 +348,7 @@ class Seq2SeqHFLM(LM):
                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        res = []

@@ -370,8 +389,8 @@ class Seq2SeqHFLM(LM):
            ).to(self.device)

            cont = self._model_generate(
                context=context_enc,
                max_length=context_enc.shape[1] + max_gen_toks,
                stop=primary_until,
                **gen_kwargs,
            )

@@ -383,4 +402,3 @@ class Seq2SeqHFLM(LM):
            res.append(s)

        return re_ord.get_original(res)
@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
    Calling this function
    """
    for root, subdirs, file_list in os.walk(task_dir):
-        if (len(file_list) > 0):
+        if len(file_list) > 0:
            for f in file_list:
                if f.endswith(".yaml"):
                    yaml_path = os.path.join(root, f)
...
@@ -19,6 +19,11 @@ import transformers
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice
+<<<<<<< HEAD
+=======
+import transformers
+>>>>>>> more pre-commit

from lm_eval.logger import eval_logger

@@ -417,6 +422,7 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
    return islice(raw_iterator, rank, limit, world_size)

+<<<<<<< HEAD
def clear_torch_cache():
    gc.collect()
    torch.cuda.empty_cache()
@@ -437,8 +443,17 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
+=======
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+    """
+    Method for padding a list of tensors given the maximum tensor
+    length in the batch. Used for batching inputs and continuations in
+    seq2seq models.
+>>>>>>> more pre-commit
    """
-    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        tensor_len = tensor.shape[0]
@@ -446,36 +461,45 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
        if padding_side == "right":
            # right-pad
            tensors[i] = torch.cat(
                [
                    tensor,  # [seq]
                    torch.zeros(
                        max_length - tensor_len,
                        dtype=torch.long,
                        device=tensor.device,
                    ),  # [padding_length - seq]
                ],
                dim=0,
            ).unsqueeze(0)
        else:
            # left-pad
            tensors[i] = torch.cat(
                [
                    torch.zeros(
                        max_length - tensor_len,
                        dtype=torch.long,
                        device=tensor.device,
                    ),  # [padding_length - seq]
                    tensor,  # [seq]
                ],
                dim=0,
            ).unsqueeze(0)
    else:
        tensors[i] = tensor.unsqueeze(0)

-    return torch.cat(tensors, dim = 0)
+    return torch.cat(tensors, dim=0)

+<<<<<<< HEAD
# Multi-token stopping criteria
+=======
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# Multi-token stopping criteria
+>>>>>>> more pre-commit
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""
...
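The `pad_and_concat` helper reformatted above right- or left-pads each tensor to a common length before stacking. A minimal, self-contained sketch of that behaviour (a simplified re-implementation for illustration, with invented example values; not the library code itself):

```python
import torch
from typing import List


def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side: str = "right"):
    # Simplified version of the helper shown in the diff above, for illustration only.
    assert padding_side in ("left", "right"), f"Unrecognized padding type: '{padding_side}'"
    out = []
    for tensor in tensors:
        pad = torch.zeros(max_length - tensor.shape[0], dtype=torch.long, device=tensor.device)
        pieces = [tensor, pad] if padding_side == "right" else [pad, tensor]
        out.append(torch.cat(pieces, dim=0).unsqueeze(0))  # [1, max_length]
    return torch.cat(out, dim=0)  # [batch, max_length]


a = torch.tensor([1, 2, 3], dtype=torch.long)
b = torch.tensor([4, 5], dtype=torch.long)
print(pad_and_concat(4, [a, b]))
# tensor([[1, 2, 3, 0],
#         [4, 5, 0, 0]])
print(pad_and_concat(4, [a, b], padding_side="left"))
# tensor([[0, 1, 2, 3],
#         [0, 0, 4, 5]])
```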