Commit eb7b9095 authored by Benjamin Fattori

batch support for loglikelihood tokens

parent 226063ce
@@ -127,7 +127,7 @@ class Seq2SeqHFLM(LM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)
 
-    def _model_call(self, inps, labels=None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
@@ -139,7 +139,7 @@ class Seq2SeqHFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids=inps, labels=labels).logits
+            return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
 
     def _model_generate(self, context, max_length, stop):
@@ -194,10 +194,11 @@ class Seq2SeqHFLM(LM):
         ):
             inps = []
             conts = []
+            encoder_attns = []
             cont_toks_list = []
 
-            padding_length_inp = None
-            padding_length_cont = None
+            max_batch_length_inp = None
+            max_batch_length_cont = None
 
             for _, context_enc, continuation_enc in chunk:
                 # sanity check
@@ -217,44 +218,22 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape
 
-                padding_length_inp = (
-                    padding_length_inp if padding_length_inp is not None else inplen
-                )
-                padding_length_cont = (
-                    padding_length_cont if padding_length_cont is not None else contlen
-                )
+                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
+                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
 
-                inp = torch.cat(
-                    [
-                        inp,  # [seq]
-                        torch.zeros(padding_length_inp - inplen, dtype=torch.long).to(
-                            inp.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
-                cont = torch.cat(
-                    [
-                        cont,  # [seq]
-                        torch.zeros(padding_length_cont - contlen, dtype=torch.long).to(
-                            cont.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
-                inps.append(inp.unsqueeze(0))  # [1, padding_length]
-                conts.append(cont.unsqueeze(0))  # [1, padding_length]
+                inps.append(inp)  # [inp_len]
+                conts.append(cont)  # [cont_len]
+                encoder_attns.append(torch.ones_like(inp))
                 cont_toks_list.append(continuation_enc)
 
-            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
-            batched_conts = torch.cat(conts, dim=0)  # [batch, padding_length]
-            # need to make attention mask here too
+            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
 
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, labels=batched_conts), dim=-1
+                self._model_call(batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts), dim=-1
             ).cpu()  # [batch, padding_length, vocab]
 
             for (cache_key, _, _), logits, cont_toks in zip(
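For orientation, the batching scheme above can be exercised end to end with a short standalone script. The following is a minimal sketch, assuming a generic Hugging Face seq2seq checkpoint (t5-small here); the example strings, the local pad_and_concat helper, and all variable names are illustrative and not the harness's own code.

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


def pad_and_concat(max_length, tensors):
    # Right-pad each 1-D tensor with zeros to max_length, then stack into [batch, max_length].
    return torch.stack([F.pad(t, (0, max_length - t.shape[0])) for t in tensors], dim=0)


tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

# (context, continuation) pairs to score in one batch.
pairs = [
    ("Translate English to German: The house is", " wonderful."),
    ("Translate English to German: Hello", " there, friend."),
]

inps, conts, encoder_attns, cont_lens = [], [], [], []
max_inp, max_cont = 0, 0
for context, continuation in pairs:
    inp = torch.tensor(tokenizer(context).input_ids, dtype=torch.long)
    cont = torch.tensor(tokenizer(continuation).input_ids, dtype=torch.long)
    max_inp, max_cont = max(max_inp, inp.shape[0]), max(max_cont, cont.shape[0])
    inps.append(inp)
    conts.append(cont)
    encoder_attns.append(torch.ones_like(inp))  # 1 for real tokens; padding added below stays 0
    cont_lens.append(cont.shape[0])

batched_inps = pad_and_concat(max_inp, inps)                   # [batch, max_inp]
batched_conts = pad_and_concat(max_cont, conts)                # [batch, max_cont]
batched_encoder_mask = pad_and_concat(max_inp, encoder_attns)  # [batch, max_inp]

with torch.no_grad():
    # Passing `labels` makes the seq2seq model build its decoder inputs from them,
    # so logits[i, t] is the distribution over the t-th continuation token of example i.
    logits = model(
        input_ids=batched_inps,
        attention_mask=batched_encoder_mask,
        labels=batched_conts,
    ).logits  # [batch, max_cont, vocab]

log_probs = F.log_softmax(logits, dim=-1)
for i, (cont, n) in enumerate(zip(conts, cont_lens)):
    # Only the first n positions are real continuation tokens, so the zero padding
    # never contributes to the summed log-likelihood.
    token_logprobs = torch.gather(log_probs[i, :n], 1, cont.unsqueeze(-1)).squeeze(-1)
    print(pairs[i], token_logprobs.sum().item())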
@@ -14,7 +14,7 @@ from typing import List
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
+import torch
 
 
 class ExitCodeError(Exception):
     pass
@@ -327,3 +327,26 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     among ranks in multigpu setting or only pulling a sample of documents
     """
     return islice(raw_iterator, rank, limit, world_size)
+
+
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor]):
+    """
+    Pads each 1-D tensor in `tensors` up to `max_length` (the longest
+    length in the batch) and concatenates the results into a
+    [batch, max_length] tensor. Used for batching inputs and
+    continuations in seq2seq models.
+    """
+    for i, tensor in enumerate(tensors):
+        tensor_len = tensor.shape[0]
+        if tensor_len < max_length:
+            tensors[i] = torch.cat(
+                [
+                    tensor,  # [seq]
+                    torch.zeros(max_length - tensor_len, dtype=torch.long).to(
+                        tensor.device
+                    ),  # [max_length - seq]
+                ],
+                dim=0,
+            ).unsqueeze(0)
+        else:
+            tensors[i] = tensor.unsqueeze(0)
+    return torch.cat(tensors, dim=0)
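A quick usage sketch of the new helper, with made-up token ids rather than anything from the repository:

import torch

# Three variable-length 1-D token tensors.
batch = [
    torch.tensor([5, 6, 7], dtype=torch.long),
    torch.tensor([8, 9], dtype=torch.long),
    torch.tensor([1, 2, 3, 4], dtype=torch.long),
]
max_len = max(t.shape[0] for t in batch)

padded = pad_and_concat(max_len, batch)
print(padded.shape)  # torch.Size([3, 4])
print(padded)
# tensor([[5, 6, 7, 0],
#         [8, 9, 0, 0],
#         [1, 2, 3, 4]])

Note that the helper mutates the list it is given: each entry is replaced in place by its padded, unsqueezed [1, max_length] version before the final concatenation.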