Commit ba2b849c authored by Leo Gao

GPT2: add device parameter

parent 5bb5e96d
@@ -5,12 +5,13 @@ import torch.nn.functional as F
 class GPT2LM(LM):
-    def __init__(self):
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
+    def __init__(self, dev='cpu'):
+        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(dev)
         self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
+        self.dev = dev
     def generate(self, context, until):
-        context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long)
+        context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long).to(self.dev)
         res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024)
         # chop off the prompt and the final eos token
@@ -18,7 +19,7 @@ class GPT2LM(LM):
     def loglikelihood(self, context, continuation):
         print('likelihood:', context, continuation)
-        inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long)
+        inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long).to(self.dev)
         ctxlen = len(self.tok.encode(context.strip()))
         cont_toks = inp[:, ctxlen:] # [batch, seq]
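For reference, a minimal usage sketch of the new dev parameter, assuming an import path for GPT2LM (the module location is not shown in this commit) and calling only the methods touched by the diff:

# Hypothetical usage of the new dev parameter; the import path is assumed.
import torch
from models.gpt2 import GPT2LM  # assumed location of the class above

# Put the model on the GPU when one is available, otherwise keep the CPU default.
dev = 'cuda' if torch.cuda.is_available() else 'cpu'
lm = GPT2LM(dev=dev)

# generate() and loglikelihood() now build their input tensors with .to(self.dev),
# so the inputs land on the same device as the model weights.
lm.generate("The quick brown fox", until='.')
lm.loglikelihood("The quick brown fox", " jumps over the lazy dog")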