Commit ffea4dc5 authored by Leo Gao's avatar Leo Gao
Browse files

Fix GPT2 impl partially

TODO: still need to add `until` everywhere
parent 5888a695
......@@ -17,19 +17,22 @@ class GPT2LM(LM):
return cls(device=args.get("device", "cpu"))
def generate(self, context, max_gen_length, truncate=True):
context_tensor = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long).to(self.device)
# when too long to fit in context, truncate from the left
context_tensor = torch.tensor([self.tokenizer.encode(context.strip())[max_gen_length - 1024:]], dtype=torch.long).to(self.device)
res = self.gpt2.generate(
context_tensor,
# TODO: change to have until rather than using eos_token_id
eos_token_id=self.tokenizer.eos_token_id,
do_sample=False,
max_length=self.num_tokens(context) + max_gen_length,
)
# chop off the prompt and the final eos token
return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()
return self.tokenizer.decode(res[0][min(1024 - max_gen_length, len(context_tensor[0])):-1]).strip()
def loglikelihood(self, context, continuation, truncate=True):
inp = torch.tensor([self.tokenizer.encode(context + continuation)], dtype=torch.long).to(self.device)
# when too long to fit in context, truncate from the left
inp = torch.tensor([self.tokenizer.encode(context + continuation)[-1024:]], dtype=torch.long).to(self.device)
ctxlen = len(self.tokenizer.encode(context.strip()))
cont_toks = inp[:, ctxlen:] # [batch, seq]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment