Commit 1bd6229c authored by haileyschoelkopf, committed by lintangsutawika

remove some old code, edge-case seq2seq case

parent 9f36ab18
@@ -120,12 +120,16 @@ class HFLM(LM):
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def config(self):
# return the associated transformers.AutoConfig for the given pretrained model.
return self._config
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
return self.tokenizer.eos_token_id
# TODO: add a self.config property
# TODO: make model at self._model, have self.model property unwrap accelerator if needed under hood?
@property
def max_length(self):
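For context, a minimal usage sketch of the config and eot_token_id properties added in the hunk above (the import path and constructor arguments are assumptions about this class's interface, not shown in the diff):

from lm_eval.models.huggingface import HFLM  # assumed module path

lm = HFLM(pretrained="gpt2", device="cpu")  # hypothetical instantiation
print(lm.config.vocab_size)  # the transformers.AutoConfig exposed by the new property
print(lm.eot_token_id)       # the tokenizer's EOS id, used here as end-of-text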
@@ -378,7 +382,8 @@ class HFLM(LM):
inp = torch.tensor(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
dtype=torch.long,
).to(self.device)
device=self.device,
)
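A side note on the change just above: passing device= to torch.tensor allocates the tensor directly on the target device, which gives the same result as building it on CPU and calling .to(). A quick standalone check (not part of the diff):

import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.tensor([1, 2, 3], dtype=torch.long).to(dev)      # old style: allocate, then move
b = torch.tensor([1, 2, 3], dtype=torch.long, device=dev)  # new style: allocate on the device
assert torch.equal(a, b) and a.device == b.device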
(inplen,) = inp.shape
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
inp = torch.tensor(
@@ -387,27 +392,19 @@ class HFLM(LM):
).to(self.device)
(inplen,) = inp.shape
cont = torch.tensor(
(continuation_enc)[-self.max_length :],
(continuation_enc)[-self.max_length :],
# TODO: left-shift these?
# TODO: our code assumes we never end up truncating conts for either model type
dtype=torch.long,
).to(self.device)
(contlen,) = cont.shape
conts.append(cont)
padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
# # pad length from seq to padding_length
# inp = torch.cat(
# [
# inp, # [seq]
# torch.zeros(padding_length - inplen, dtype=torch.long).to(
# inp.device
# ), # [padding_length - seq]
# ],
# dim=0,
# )
inps.append(inp) # [1, inp_length]
cont_toks_list.append(continuation_enc)
inplens.append(inplen)
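The hunk below batches these per-example tensors with utils.pad_and_concat, using the tracked padding_len_inp / padding_len_cont. That helper's body is outside this diff; a minimal sketch of its assumed behavior (pad each 1-D tensor to a shared length on the chosen side, then stack into [batch, length]):

import torch

def pad_and_concat(max_length, tensors, padding_side="right"):
    # Pad each 1-D tensor with zeros up to max_length, then stack to [batch, max_length].
    padded = []
    for t in tensors:
        pad = torch.zeros(max_length - t.shape[0], dtype=t.dtype, device=t.device)
        row = torch.cat([t, pad]) if padding_side == "right" else torch.cat([pad, t])
        padded.append(row.unsqueeze(0))
    return torch.cat(padded, dim=0)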
@@ -415,18 +412,17 @@ class HFLM(LM):
# create encoder attn mask and batched conts, if seq2seq
call_kwargs = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# batched_inps = torch.cat(inps, dim=0) # [batch, padding_length]
batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")
batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right") # [batch, padding_len_inp]
elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
# TODO: left-pad encoder inps and mask?
batched_inps = utils.pad_and_concat(padding_len_inp, inps) # [batch, enc_padding_length]
batched_conts = utils.pad_and_concat(padding_len_cont, conts) # [batch, padding_length]
batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns) # size???
batched_inps = utils.pad_and_concat(padding_len_inp, inps) # [batch, padding_len_inp]
batched_conts = utils.pad_and_concat(padding_len_cont, conts) # [batch, padding_len_cont]
batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns) # [batch, padding_len_inp]
call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
multi_logits = F.log_softmax(
self._model_call(batched_inps, **call_kwargs), dim=-1
).cpu() # [batch, padding_length, vocab]
).cpu() # [batch, padding_length (inp or cont), vocab]
for (cache_key, _, _), logits, inplen, cont_toks in zip(
chunk, multi_logits, inplens, cont_toks_list
@@ -436,7 +432,8 @@ class HFLM(LM):
contlen = len(cont_toks)
# take only logits in the continuation
# (discard context toks if decoder-only ; discard right-padding)
logits = self._select_cont_toks(logits, contlen=contlen, inplen=inplen)
ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
logits = logits.unsqueeze(
0
) # [1, seq, vocab]
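For reference on the final hunk: the new ctx_len argument lets _select_cont_toks distinguish the two model types when slicing out continuation logits. A standalone sketch of the assumed selection logic (the harness's actual helper is outside this diff):

def select_cont_toks(logits, contlen, ctx_len=None):
    # Decoder-only: logits at positions [ctx_len - contlen, ctx_len) predict the continuation tokens.
    if ctx_len is not None:
        return logits[ctx_len - contlen : ctx_len]
    # Encoder-decoder: the decoder output aligns with the labels, so take the first contlen positions.
    return logits[:contlen]

Because decoder-only inputs are right-padded, this slice stays within the real (non-padding) positions; for the seq2seq case the decoder sees only the continuation, so no context offset is needed.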