Unverified Commit 2b8956b8 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

gpt2: Mask out all tokens above 50256

parent efbe6e7f
......@@ -61,7 +61,7 @@ class GPT2LM(LM):
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
logits = F.log_softmax(self.gpt2(inp)[0][:, :, :50257], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
greedy_tokens = logits.argmax(dim=-1)
max_equal = (greedy_tokens == cont_toks).all()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment