Commit fe0311b6 authored by Leo Gao
Browse files

Fix problems

parent be3a6a2d
...@@ -79,7 +79,6 @@ class GPT2LM(LM): ...@@ -79,7 +79,6 @@ class GPT2LM(LM):
for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size): for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
inps = [] inps = []
ctxlens = [] ctxlens = []
inplens = []
padding_length = None padding_length = None
for _, context_enc, continuation_enc in chunk: for _, context_enc, continuation_enc in chunk:
...@@ -94,26 +93,26 @@ class GPT2LM(LM): ...@@ -94,26 +93,26 @@ class GPT2LM(LM):
# pad to length # pad to length
inp = torch.cat([ inp = torch.cat([
inp, # [seq] inp, # [seq]
torch.zeros(padding_length - inplen, dtype=torch.long) # [padding_length - seq] torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq]
], dim=0) ], dim=0)
inps.append(inp) inps.append(inp.unsqueeze(0))
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
inplens.append(inplen)
multi_logits = F.log_softmax(self.gpt2(torch.stack(inps, dim=0))[0][:, :, :50257], dim=-1) # [batch, seq, vocab] multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1) # [batch, seq, vocab]
for (cache_key, _, _), logits, ctxlen, inplens in zip(chunk, multi_logits, ctxlens, inplens): for (cache_key, _, _), logits, ctxlen, inp in zip(chunk, multi_logits, ctxlens, inps):
logits = logits[ctxlen - 1:inplen - 1] # [seq, vocab] _, inplen = inp.shape
logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1) greedy_tokens = logits.argmax(dim=-1)
cont_toks = inp[:, ctxlen:] # [batch, seq] cont_toks = inp[:, ctxlen:] # [1, seq]
max_equal = (greedy_tokens == cont_toks).all() max_equal = (greedy_tokens == cont_toks).all()
last_token_slice = logits[:, -1, :].squeeze(0).tolist() last_token_slice = logits[:, -1, :].squeeze(0).tolist()
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [batch, seq] logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq]
answer = (float(logits.sum()), bool(max_equal)) answer = (float(logits.sum()), bool(max_equal))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment