Commit 747b851d authored by Leo Gao

Update GPT2LM to handle neo based models as well

parent c84a4af4
@@ -11,10 +11,11 @@ class GPT2LM(LM):
     def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
+        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
         self.gpt2.eval()
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
         self.tokenizer.pad_token = "<|endoftext|>"
+        self.max_length = self.gpt2.config.n_ctx
         assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
@@ -39,8 +40,8 @@ class GPT2LM(LM):
         context_enc = self.tokenizer.encode(context)
         continuation_enc = self.tokenizer.encode(continuation)
-        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+        inp = torch.tensor([(context_enc + continuation_enc)[-self.max_length:]], dtype=torch.long).to(self.device)
+        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
         cont_toks = inp[:, ctxlen:]  # [batch, seq]
         logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
@@ -63,7 +64,7 @@ class GPT2LM(LM):
         for context, until in tqdm(requests):
             if isinstance(until, str): until = [until]
-            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - 1024:]]).to(self.device)
+            context_enc = torch.tensor([self.tokenizer.encode(context)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
             primary_until, = self.tokenizer.encode(until[0])
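With AutoModelForCausalLM doing model resolution and the context window read from config.n_ctx rather than being hard-coded to GPT-2's 1024, neo-based checkpoints with longer contexts can load through the same class. A minimal usage sketch, assuming a checkpoint whose config exposes n_ctx and whose tokenizer is GPT-2 compatible (the checkpoint name below is illustrative, not part of this commit):

    # Illustrative checkpoint name; GPT2LM's assert requires a GPT-2
    # compatible tokenizer, and self.max_length requires config.n_ctx.
    lm = GPT2LM(device="cuda", pretrained="EleutherAI/gpt-neo-1.3B")

    # Truncation now follows the model's own window: with max_length 2048,
    # a 2000-token context plus a 100-token continuation keeps the last
    # 2048 tokens and scores the continuation starting at
    # ctxlen = 2000 - max(0, 2000 + 100 - 2048) = 1948.
    print(lm.max_length)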