Commit 1bd6229c authored by haileyschoelkopf, committed by lintangsutawika

remove some old code, edge-case seq2seq case

parent 9f36ab18
@@ -120,12 +120,16 @@ class HFLM(LM):
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes
 
+    @property
+    def config(self):
+        # return the associated transformers.AutoConfig for the given pretrained model.
+        return self._config
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id
 
-    # TODO: add a self.config property
     # TODO: make model at self._model, have self.model property unwrap accelerator if needed under hood?
     @property
     def max_length(self):
@@ -378,7 +382,8 @@ class HFLM(LM):
                 inp = torch.tensor(
                     (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                     dtype=torch.long,
-                ).to(self.device)
+                    device=self.device,
+                )
                 (inplen,) = inp.shape
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 inp = torch.tensor(
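The `.to(self.device)` to `device=self.device` change in the hunk above constructs the input tensor directly on the target device rather than allocating it on CPU and copying it over. A minimal standalone sketch of the two patterns (the token ids and device choice here are made-up placeholders, not values from the harness):

```python
import torch

token_ids = [101, 2054, 2003, 102]  # placeholder ids for illustration only
device = "cuda" if torch.cuda.is_available() else "cpu"

# old pattern: allocate on CPU, then copy to the device
inp_old = torch.tensor(token_ids, dtype=torch.long).to(device)

# new pattern (as in this commit): allocate directly on the target device
inp_new = torch.tensor(token_ids, dtype=torch.long, device=device)

assert torch.equal(inp_old.cpu(), inp_new.cpu())
```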
@@ -387,27 +392,19 @@ class HFLM(LM):
                 ).to(self.device)
                 (inplen,) = inp.shape
 
                 cont = torch.tensor(
                     (continuation_enc)[-self.max_length :],
-                    # TODO: left-shift these?
-                    # TODO: our code assumes we never end up truncating conts for either model type
                     dtype=torch.long,
                 ).to(self.device)
                 (contlen,) = cont.shape
+                conts.append(cont)
 
                 padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
 
             padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
 
-            # # pad length from seq to padding_length
-            # inp = torch.cat(
-            #     [
-            #         inp,  # [seq]
-            #         torch.zeros(padding_length - inplen, dtype=torch.long).to(
-            #             inp.device
-            #         ),  # [padding_length - seq]
-            #     ],
-            #     dim=0,
-            # )
             inps.append(inp)  # [1, inp_length]
             cont_toks_list.append(continuation_enc)
             inplens.append(inplen)
@@ -415,18 +412,17 @@ class HFLM(LM):
             # create encoder attn mask and batched conts, if seq2seq
             call_kwargs = {}
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                # batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")
+                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")  # [batch, padding_len_inp]
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 # TODO: left-pad encoder inps and mask?
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, enc_padding_length]
-                batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_length]
-                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # size???
+                batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
+                batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
+                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # [batch, padding_len_inp]
                 call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
 
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps, **call_kwargs), dim=-1
-            ).cpu()  # [batch, padding_length, vocab]
+            ).cpu()  # [batch, padding_length (inp or cont), vocab]
 
             for (cache_key, _, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
@@ -436,7 +432,8 @@ class HFLM(LM):
                 contlen = len(cont_toks)
                 # take only logits in the continuation
                 # (discard context toks if decoder-only ; discard right-padding)
-                logits = self._select_cont_toks(logits, contlen=contlen, inplen=inplen)
+                ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                 logits = logits.unsqueeze(
                     0
                 )  # [1, seq, vocab]
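For context on the call sites above: `utils.pad_and_concat` itself is not part of this diff, so the sketch below is only an assumption about its contract inferred from how it is used here (pad each 1-D token tensor to a shared length and stack the results into a `[batch, padding_len]` tensor); the project's actual helper may differ in details.

```python
import torch

def pad_and_concat_sketch(max_len, tensors, padding_side="right"):
    # Hypothetical stand-in: zero-pad each 1-D tensor to max_len, stack to [batch, max_len].
    rows = []
    for t in tensors:
        pad = torch.zeros(max_len - t.shape[0], dtype=t.dtype, device=t.device)
        if padding_side == "right":
            rows.append(torch.cat([t, pad], dim=0).unsqueeze(0))
        else:  # left padding
            rows.append(torch.cat([pad, t], dim=0).unsqueeze(0))
    return torch.cat(rows, dim=0)

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad_and_concat_sketch(3, batch))  # tensor([[1, 2, 3], [4, 5, 0]])
```

The renamed shape comments and the new `ctx_len` line in the last hunk follow the same logic: a causal model sees context and continuation in one right-padded sequence, so `_select_cont_toks` needs `inplen` to slice the context away, while the seq2seq decoder's logits already align with the cont-padded labels, so no input length is passed.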