Commit c2aaa501 authored by Jason Phang

combine gpt2

parent 2d4b3a8c
import transformers
from base import LM
import torch
import torch.nn.functional as F


class GPT2LM(LM):
    def __init__(self, dev='cpu'):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(dev)
        self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.dev = dev

    def generate(self, context, until):
        context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long).to(self.dev)
        # greedy decoding; `until` must be a single token in the GPT-2 vocab
        res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024)
        # chop off the prompt and the final eos token
        return self.tok.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        print('likelihood:', context, continuation)
        inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long).to(self.dev)
        ctxlen = len(self.tok.encode(context.strip()))
        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        # logits at position i predict token i + 1, hence the one-position shift in the slice
        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
        # per-token log-probabilities of the continuation tokens
        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
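The slice [:, ctxlen - 1:-1] in loglikelihood is the subtle part: the logits at position i are the model's prediction for token i + 1, so shifting the window back by one aligns each continuation token with the distribution that predicted it. A minimal usage sketch of the class above (hypothetical prompts; assumes the 'gpt2' weights are downloadable and the class is importable):

lm = GPT2LM(dev='cpu')
# Score a continuation: sum of per-token log-probabilities given the context.
print(lm.loglikelihood('The capital of France is', ' Paris').sum().item())
# Greedy generation that halts at the first '.' token.
print(lm.generate('Q: Name a primary color.\nA:', until='.'))

This standalone module is the version being folded into the registry-based model below.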
import transformers
import torch
import torch.nn.functional as F

from ..base import LM
from .. import utils
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("gpt2")
class GPT2LM(LM):
    def __init__(self, device="cpu"):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(device)
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.device = device

    @classmethod
    def create_from_args(cls, arg_string):
        args = utils.simple_parse_args_string(arg_string)
        return cls(device=args.get("device", "cpu"))

    def generate(self, context, max_gen_length):
        context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long).to(self.device)
        res = self.gpt2.generate(
            context,
            eos_token_id=self.tokenizer.eos_token_id,
            do_sample=False,            # greedy decoding, as in the standalone class
            max_length=max_gen_length,  # inferred; these lines were collapsed in the diff view (@@ -23,4 +31,10 @@)
        )
        # chop off the prompt and the final eos token
        return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        inp = torch.tensor([self.tokenizer.encode(context + continuation)], dtype=torch.long).to(self.device)
        ctxlen = len(self.tokenizer.encode(context.strip()))
        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
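create_from_args gives the registry a uniform way to construct models: utils.simple_parse_args_string evidently parses a CLI-style string such as "device=cpu" into a dict, and "device" is the only key this class reads. A short hypothetical sketch of that path (argument-string format assumed from the code above):

lm = GPT2LM.create_from_args("device=cpu")  # equivalent to GPT2LM(device="cpu")
# The registry can now build any registered model from a name plus an
# argument string, without knowing each constructor's signature.
print(lm.generate("Once upon a time,", max_gen_length=64))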