Commit 31c29e3b authored by Leo Gao
Browse files

Actually fix problems and make tests pass

parent fe0311b6
...@@ -78,6 +78,7 @@ class GPT2LM(LM): ...@@ -78,6 +78,7 @@ class GPT2LM(LM):
reord = utils.Reorderer(requests, _collate) reord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size): for chunk in utils.chunks(tqdm(reord.get_reordered()), self.batch_size):
inps = [] inps = []
inplens = []
ctxlens = [] ctxlens = []
padding_length = None padding_length = None
...@@ -97,17 +98,16 @@ class GPT2LM(LM): ...@@ -97,17 +98,16 @@ class GPT2LM(LM):
], dim=0) ], dim=0)
inps.append(inp.unsqueeze(0)) inps.append(inp.unsqueeze(0))
inplens.append(inplen)
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1) # [batch, seq, vocab] multi_logits = F.log_softmax(self.gpt2(torch.cat(inps, dim=0))[0][:, :, :50257], dim=-1) # [batch, seq, vocab]
for (cache_key, _, _), logits, ctxlen, inp in zip(chunk, multi_logits, ctxlens, inps): for (cache_key, _, _), logits, ctxlen, inp, inplen in zip(chunk, multi_logits, ctxlens, inps, inplens):
_, inplen = inp.shape
logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0) # [1, seq, vocab] logits = logits[ctxlen - 1:inplen - 1].unsqueeze(0) # [1, seq, vocab]
greedy_tokens = logits.argmax(dim=-1) greedy_tokens = logits.argmax(dim=-1)
cont_toks = inp[:, ctxlen:inplen] # [1, seq]
cont_toks = inp[:, ctxlen:] # [1, seq]
max_equal = (greedy_tokens == cont_toks).all() max_equal = (greedy_tokens == cont_toks).all()
last_token_slice = logits[:, -1, :].squeeze(0).tolist() last_token_slice = logits[:, -1, :].squeeze(0).tolist()
......
...@@ -34,4 +34,4 @@ def test_gpt2(): ...@@ -34,4 +34,4 @@ def test_gpt2():
targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281] targets = [-61.60536193847656, -56.57843780517578, -62.131004333496094, -9.799489974975586, -153.96334838867188, -341.222900390625, -731.1475830078125, -61.60536193847656, -8.682319641113281]
for (pred, _), tgt in zip(vals, targets): for (pred, _), tgt in zip(vals, targets):
assert pred == pytest.approx(tgt) assert pred == pytest.approx(tgt, abs=1e-3)
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment