Commit ba2b849c authored by Leo Gao's avatar Leo Gao
Browse files

GPT2: add device parameter

parent 5bb5e96d
...@@ -5,12 +5,13 @@ import torch.nn.functional as F ...@@ -5,12 +5,13 @@ import torch.nn.functional as F
class GPT2LM(LM): class GPT2LM(LM):
def __init__(self): def __init__(self, dev='cpu'):
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2') self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(dev)
self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2') self.tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
self.dev = dev
def generate(self, context, until): def generate(self, context, until):
context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long) context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long).to(self.dev)
res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024) res = self.gpt2.generate(context, eos_token_id=self.tok.encoder[until], do_sample=False, max_length=1024)
# chop off the prompt and the final eos token # chop off the prompt and the final eos token
...@@ -18,7 +19,7 @@ class GPT2LM(LM): ...@@ -18,7 +19,7 @@ class GPT2LM(LM):
def loglikelihood(self, context, continuation): def loglikelihood(self, context, continuation):
print('likelihood:', context, continuation) print('likelihood:', context, continuation)
inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long) inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long).to(self.dev)
ctxlen = len(self.tok.encode(context.strip())) ctxlen = len(self.tok.encode(context.strip()))
cont_toks = inp[:, ctxlen:] # [batch, seq] cont_toks = inp[:, ctxlen:] # [batch, seq]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment