"src/diffusers/pipelines/allegro/pipeline_allegro.py" did not exist on "b934215d4c376ea2e08e28103443686b95ea772c"
Commit 3789d340 authored by Benjamin Fattori, committed by lintangsutawika

batch support for loglikelihood tokens

parent 0a3b8069
@@ -127,7 +127,7 @@ class Seq2SeqHFLM(LM):
     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)
 
-    def _model_call(self, inps, labels = None):
+    def _model_call(self, inps, attn_mask = None, labels = None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
@@ -139,7 +139,7 @@ class Seq2SeqHFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids = inps, labels = labels).logits
+            return self.model(input_ids = inps, attention_mask = attn_mask, labels = labels).logits
 
     def _model_generate(self, context, max_length, stop):
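The signature change above threads an explicit encoder attention mask through to the underlying HuggingFace seq2seq model. A minimal sketch of the forwarded call, assuming a T5-style checkpoint ("t5-small" and the example strings are only for illustration, not part of this commit):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Tokenize a padded batch of encoder inputs and the target (decoder) tokens.
enc = tok(
    ["translate English to German: Hello", "summarize: a longer input text"],
    return_tensors="pt",
    padding=True,
)
labels = tok(["Hallo", "summary"], return_tensors="pt", padding=True).input_ids

with torch.no_grad():
    # attention_mask keeps the encoder from attending to padded positions;
    # this is what the new attn_mask argument forwards.
    logits = model(
        input_ids=enc.input_ids,
        attention_mask=enc.attention_mask,
        labels=labels,
    ).logits  # [batch, label_len, vocab]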
@@ -194,10 +194,11 @@ class Seq2SeqHFLM(LM):
         ):
             inps = []
             conts = []
+            encoder_attns = []
             cont_toks_list = []
 
-            padding_length_inp = None
-            padding_length_cont = None
+            max_batch_length_inp = None
+            max_batch_length_cont = None
 
             for _, context_enc, continuation_enc in chunk:
                 # sanity check
@@ -217,44 +218,22 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
 
                 (contlen,) = cont.shape
 
-                padding_length_inp = (
-                    padding_length_inp if padding_length_inp is not None else inplen
-                )
-                padding_length_cont = (
-                    padding_length_cont if padding_length_cont is not None else contlen
-                )
-                inp = torch.cat(
-                    [
-                        inp,  # [seq]
-                        torch.zeros(padding_length_inp - inplen, dtype=torch.long).to(
-                            inp.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
+                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
+                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
 
-                cont = torch.cat(
-                    [
-                        cont,  # [seq]
-                        torch.zeros(padding_length_cont - contlen, dtype=torch.long).to(
-                            cont.device
-                        ),  # [padding_length - seq]
-                    ],
-                    dim=0,
-                )
-
-                inps.append(inp.unsqueeze(0))  # [1, padding_length]
-                conts.append(cont.unsqueeze(0))  # [1, padding_length]
+                inps.append(inp)  # [1, inp_len]
+                conts.append(cont)  # [1, cont_len]
+                encoder_attns.append(torch.ones_like(inp))
 
                 cont_toks_list.append(continuation_enc)
 
-            batched_inps = torch.cat(inps, dim=0)  # [batch, padding_length]
-            batched_conts = torch.cat(conts, dim=0)  # [batch, padding_length]
+            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
+            # need to make attention mask here too
 
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, labels = batched_conts), dim=-1
+                self._model_call(batched_inps, attn_mask = batched_encoder_mask, labels = batched_conts), dim=-1
             ).cpu()  # [batch, padding_length, vocab]
 
             for (cache_key, _, _), logits, cont_toks in zip(
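The new batching path relies on a utils.pad_and_concat(max_length, tensors) helper that is not shown in this diff. Below is a minimal sketch of what such a helper might look like, assuming it right-pads each 1-D token tensor with zeros to max_length and stacks the results into a [batch, max_length] tensor; the name and signature follow the calls in the hunk above, but the body is illustrative rather than the repository's actual implementation.

import torch

def pad_and_concat(max_length, tensors):
    # Hypothetical sketch: right-pad each 1-D tensor with zeros up to
    # max_length, then stack everything into one [batch, max_length] tensor.
    padded = []
    for t in tensors:
        pad_len = max_length - t.shape[0]
        pad = torch.zeros(pad_len, dtype=t.dtype, device=t.device)
        padded.append(torch.cat([t, pad], dim=0).unsqueeze(0))
    return torch.cat(padded, dim=0)

# Usage mirroring the diff: pad token ids and the matching all-ones attention
# masks to the longest sequence in the batch; zeros then mark padded positions.
inps = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
encoder_attns = [torch.ones_like(t) for t in inps]
batched_inps = pad_and_concat(3, inps)                    # shape [2, 3]
batched_encoder_mask = pad_and_concat(3, encoder_attns)   # [[1, 1, 1], [1, 1, 0]]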