Commit aab91285 authored by Leo Gao

Update gpt2 for efficiency and allow specifying model size

parent 4d8ed7d5
...
@@ -7,44 +7,45 @@ from tqdm import tqdm

 class GPT2LM(LM):
-    def __init__(self, device="cpu"):
+    def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
-        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
+        self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
         self.gpt2.eval()
-        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained(pretrained)
         self.tokenizer.pad_token = "<|endoftext|>"

     @classmethod
     def create_from_arg_string(cls, arg_string):
         args = utils.simple_parse_args_string(arg_string)
-        return cls(device=args.get("device", "cpu"))
+        return cls(device=args.get("device", "cpu"), pretrained=args.get("pretrained", "gpt2"))

     def loglikelihood(self, requests):
         res = []
-        # TODO: vectorize properly
-        for context, continuation in tqdm(requests):
-            # when too long to fit in context, truncate from the left
-            if context == "":
-                # end of text as context
-                context_enc = [50256]
-            else:
-                context_enc = self.tokenizer.encode(context)
-
-            continuation_enc = self.tokenizer.encode(continuation)
-            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
-
-            cont_toks = inp[:, ctxlen:]  # [batch, seq]
-            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
-
-            greedy_tokens = logits.argmax(dim=-1)
-            max_equal = (greedy_tokens == cont_toks).all()
-
-            logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
-
-            res.append((float(logits.sum()), bool(max_equal)))
+        with torch.no_grad():
+            # TODO: vectorize properly
+            for context, continuation in tqdm(requests):
+                # when too long to fit in context, truncate from the left
+                if context == "":
+                    # end of text as context
+                    context_enc = [50256]
+                else:
+                    context_enc = self.tokenizer.encode(context)
+
+                continuation_enc = self.tokenizer.encode(continuation)
+                inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+
+                cont_toks = inp[:, ctxlen:]  # [batch, seq]
+                logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+
+                greedy_tokens = logits.argmax(dim=-1)
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [batch, seq]
+
+                res.append((float(logits.sum()), bool(max_equal)))

         return res
...