Commit aab91285 authored by Leo Gao's avatar Leo Gao
Browse files

Update gpt2 for efficiency and allow specifying model size

parent 4d8ed7d5
...@@ -7,20 +7,21 @@ from tqdm import tqdm ...@@ -7,20 +7,21 @@ from tqdm import tqdm
class GPT2LM(LM): class GPT2LM(LM):
def __init__(self, device="cpu"): def __init__(self, device="cpu", pretrained='gpt2'):
self.device = torch.device(device) self.device = torch.device(device)
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device) self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
self.gpt2.eval() self.gpt2.eval()
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
self.tokenizer.pad_token = "<|endoftext|>" self.tokenizer.pad_token = "<|endoftext|>"
@classmethod @classmethod
def create_from_arg_string(cls, arg_string): def create_from_arg_string(cls, arg_string):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
return cls(device=args.get("device", "cpu")) return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))
def loglikelihood(self, requests): def loglikelihood(self, requests):
res = [] res = []
with torch.no_grad():
# TODO: vectorize properly # TODO: vectorize properly
for context, continuation in tqdm(requests): for context, continuation in tqdm(requests):
# when too long to fit in context, truncate from the left # when too long to fit in context, truncate from the left
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment