"vscode:/vscode.git/clone" did not exist on "e7bc600304e98fa54184f4d7331b4e68016890b4"
Commit e7a87e71 authored by Jason Phang
Browse files

GPT-2 fixes

parent 9987203f
......@@ -9,17 +9,17 @@ from . import MODEL_REGISTRY
@MODEL_REGISTRY.register("gpt2")
class GPT2LM(LM):
def __init__(self, device="cpu"):
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
self.device = torch.device(device)
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
self.device = device
@classmethod
def create_from_args(cls, arg_string):
def create_from_arg_string(cls, arg_string):
args = utils.simple_parse_args_string(arg_string)
return cls(device=args.get("device", "cpu"))
def generate(self, context, max_gen_length):
context = torch.tensor([self.tok.encode(context.strip())], dtype=torch.long).to(self.device)
context = torch.tensor([self.tokenizer.encode(context.strip())], dtype=torch.long).to(self.device)
res = self.gpt2.generate(
context,
eos_token_id=self.tokenizer.eos_token_id,
......@@ -28,11 +28,11 @@ class GPT2LM(LM):
)
# chop off the prompt and the final eos token
return self.tok.decode(res[0][len(context[0]):-1]).strip()
return self.tokenizer.decode(res[0][len(context[0]):-1]).strip()
def loglikelihood(self, context, continuation):
inp = torch.tensor([self.tok.encode(context + continuation)], dtype=torch.long).to(self.device)
ctxlen = len(self.tok.encode(context.strip()))
inp = torch.tensor([self.tokenizer.encode(context + continuation)], dtype=torch.long).to(self.device)
ctxlen = len(self.tokenizer.encode(context.strip()))
cont_toks = inp[:, ctxlen:] # [batch, seq]
logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1] # [batch, seq, vocab]
......
......@@ -4,12 +4,12 @@ def simple_parse_args_string(args_string):
arg1=val1,arg2=val2
Into a dictionary
"""
args_string = args_string.split()
args_string = args_string.strip()
if not args_string:
return {}
arg_list = args_string.split(",")
args_dict = {}
for arg, in arg_list:
for arg in arg_list:
k, v = arg.split("=")
args_dict[k] = v
return args_dict
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment