Commit 8966289a authored by Leo Gao's avatar Leo Gao
Browse files

Add GPT2 greedy_until truncation

parent dbdba695
...@@ -60,7 +60,7 @@ class GPT2LM(LM): ...@@ -60,7 +60,7 @@ class GPT2LM(LM):
for context, until in tqdm(requests): for context, until in tqdm(requests):
if isinstance(until, str): until = [until] if isinstance(until, str): until = [until]
context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device) context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)
primary_until, = self.tokenizer.encode(until[0]) primary_until, = self.tokenizer.encode(until[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment