Commit 747b851d authored by Leo Gao

Update GPT2LM to handle neo-based models as well

parent c84a4af4
@@ -11,10 +11,11 @@ class GPT2LM(LM):
     def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
+        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
         self.gpt2.eval()
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
         self.tokenizer.pad_token = "<|endoftext|>"
+        self.max_length = self.gpt2.config.n_ctx
         assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
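The constructor now goes through AutoModelForCausalLM, which instantiates whatever architecture the checkpoint's config names, so the same class can load GPT-2 and GPT-Neo checkpoints. A minimal sketch of that dispatch (checkpoint names below are only illustrative, not part of this commit):

```python
import transformers

# AutoModelForCausalLM reads the checkpoint's config.json and builds the
# matching model class, so one constructor covers both model families.
for name in ("gpt2", "EleutherAI/gpt-neo-125M"):  # illustrative checkpoint names
    model = transformers.AutoModelForCausalLM.from_pretrained(name)
    print(name, "->", type(model).__name__)
# Expected for these checkpoints: GPT2LMHeadModel and GPTNeoForCausalLM
```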
@@ -39,8 +40,8 @@ class GPT2LM(LM):
             context_enc = self.tokenizer.encode(context)
             continuation_enc = self.tokenizer.encode(continuation)
-            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+            inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
+            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
             cont_toks = inp[:, ctxlen:]  # [batch, seq]
             logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
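Here the hardcoded 1024-token window is replaced by self.max_length taken from the model config, so longer-context models (e.g. GPT-Neo's 2048 positions) are no longer truncated unnecessarily. A toy check of the truncation arithmetic, using made-up lengths:

```python
# Hypothetical sizes: a 6-token context plus a 4-token continuation against an
# 8-token window. Two context tokens fall off the left; the continuation is kept.
context_enc = list(range(6))
continuation_enc = list(range(100, 104))
max_length = 8
inp = (context_enc + continuation_enc)[-max_length:]
ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - max_length)
assert ctxlen == 4
assert inp[ctxlen:] == continuation_enc  # continuation tokens always survive the cut
```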
@@ -63,7 +64,7 @@ class GPT2LM(LM):
         for context, until in tqdm(requests):
             if isinstance(until, str): until = [until]
-            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)
+            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
             primary_until, = self.tokenizer.encode(until[0])
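The generation path gets the same substitution: slicing with self.MAX_GEN_TOKS - self.max_length keeps only as much context as leaves room for MAX_GEN_TOKS generated tokens inside the model's window. A small sketch of the slicing, with assumed values (the constants below are illustrative, not taken from this commit):

```python
# Illustrative values only: with a 2048-token window and room reserved for 256
# generated tokens, only the last 1792 context tokens are kept.
MAX_GEN_TOKS, max_length = 256, 2048
tokens = list(range(5000))  # an overlong encoded context
kept = tokens[MAX_GEN_TOKS - max_length:]
assert len(kept) == max_length - MAX_GEN_TOKS == 1792
```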