"torchvision/csrc/cuda/ps_roi_align_kernel.cu" did not exist on "2e1e0b63145a54746d3b47c5267cc2521f113a9f"
Commit c2aaa501 authored by Jason Phang

combine gpt2

parent 2d4b3a8c
# Original standalone GPT-2 wrapper (pre-commit version).
import transformers
from base import LM
import torch
import torch.nn.functional as F


class GPT2LM(LM):
    def __init__(self, dev='cpu'):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(dev)
        self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.dev = dev

    def generate(self, context, until):
        # Greedy decoding that stops once the `until` token is produced.
        context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long).to(self.dev)
        res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024)
        # chop off the prompt and the final eos token
        return self.tok.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
        print('likelihood:', context, continuation)
        inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long).to(self.dev)
        ctxlen = len(self.tok.encode(context.strip()))
        cont_toks = inp[:, ctxlen:]  # [batch, seq]
        # Logits at position i predict token i + 1, so [ctxlen - 1:-1] lines up with the continuation tokens.
        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
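Not part of the commit: a minimal sketch of the index arithmetic in loglikelihood above, with random log-probabilities standing in for the model output. GPT-2's logits at position i predict token i + 1, so slicing with [ctxlen - 1:-1] picks exactly the positions that score the continuation tokens.

import torch
import torch.nn.functional as F

batch, seq, vocab = 1, 5, 10
ctxlen = 3                               # first three tokens are the context
inp = torch.tensor([[4, 7, 1, 8, 2]])    # toy token ids, shape [batch, seq]
logits = F.log_softmax(torch.randn(batch, seq, vocab), dim=-1)  # stand-in for self.gpt2(inp)[0]

cont_toks = inp[:, ctxlen:]              # continuation ids: [[8, 2]]
scores = logits[:, ctxlen - 1:-1]        # positions 2 and 3 predict tokens 3 and 4
per_tok = torch.gather(scores, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
print(per_tok.shape)                     # torch.Size([1, 2]): one log-prob per continuation token
print(per_tok.sum())                     # total log-likelihood of the continuation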
# Registry-based GPT-2 wrapper after this commit. Changed lines are marked
# with -/+ below; the generate() hunk is truncated at the @@ header.
import transformers
import torch
import torch.nn.functional as F
from ..base import LM
from .. import utils
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("gpt2")
class GPT2LM(LM):
-    def __init__(self):
+    def __init__(self, device="cpu"):
        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.device = device
+
+    @classmethod
+    def create_from_args(cls, arg_string):
+        args = utils.simple_parse_args_string(arg_string)
+        return cls(device=args.get("device", "cpu"))

    def generate(self, context, max_gen_length):
-        context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long)
+        context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long).to(self.device)
        res = self.gpt2.generate(
            context,
            eos_token_id=self.tokenizer.eos_token_id,
@@ -23,4 +31,10 @@ class GPT2LM(LM):
        return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()

    def loglikelihood(self, context, continuation):
-        pass
+        inp = torch.tensor([self.tokenizer.encode(context + continuation)], dtype=torch.long).to(self.device)
+        ctxlen = len(self.tokenizer.encode(context.strip()))
+        cont_toks = inp[:, ctxlen:]  # [batch, seq]
+        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
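A hedged usage sketch, not part of the commit: it assumes the new module is importable as lm_eval.models.gpt2 (the actual file path is not shown in this diff) and that utils.simple_parse_args_string parses a comma-separated "key=value" string into a dict, which is what create_from_args relies on.

from lm_eval.models.gpt2 import GPT2LM  # hypothetical import path, not shown in the diff

# Equivalent to GPT2LM(device="cpu"); a string like "device=cuda" would also need
# the model itself moved to that device, which the shown hunks do not do.
lm = GPT2LM.create_from_args("device=cpu")

# Greedy generation that stops at the tokenizer's EOS token.
completion = lm.generate("Q: What is the capital of France?\nA:", max_gen_length=32)

# Per-token log-probabilities of the continuation given the context,
# shape [batch, num_continuation_tokens]; sum for the total log-likelihood.
scores = lm.loglikelihood("Q: What is the capital of France?\nA:", " Paris")
print(completion, scores.sum().item())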