Commit 7f24a08b authored by Leo Gao

Refactor LM organization for more reuse

parent e5066c69
import abc
import random
+from typing import Iterable
import numpy as np
import re
+from tqdm import tqdm

from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
+from lm_eval import utils


class LM(abc.ABC):
...
@@ -96,20 +99,70 @@ class LM(abc.ABC):
        pass

    @classmethod
-    def create_from_arg_string(cls, arg_string):
-        """Constructor method, in case models need additional arguments
-        e.g. OpenAI API engine, paths for loading, other params
-
-        :param arg_string: str
-            Left up to individual model class to handle
-        """
-        return cls()
+    def create_from_arg_string(cls, arg_string, additional_config={}):
+        args = utils.simple_parse_args_string(arg_string)
+        args2 = {k: v for k, v in additional_config.items() if v is not None}
+        return cls(**args, **args2)

    def set_cache_hook(self, cache_hook):
        self.cache_hook = cache_hook
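
The rewritten `create_from_arg_string` above standardizes on a comma-separated `key=value` argument string plus an optional override dict. For reference, a self-contained sketch of that parsing convention (a hypothetical stand-in for illustration, not the harness's own `utils.simple_parse_args_string`):

    def parse_args_string(arg_string):
        # "pretrained=gpt2,device=cuda" -> {"pretrained": "gpt2", "device": "cuda"}
        return dict(kv.split("=", 1) for kv in arg_string.split(",") if kv)

    args = parse_args_string("pretrained=gpt2,device=cuda")
    # overrides with value None are dropped, mirroring the additional_config handling above
    overrides = {k: v for k, v in {"device": "cuda:0", "batch_size": None}.items() if v is not None}
    print({**args, **overrides})  # {'pretrained': 'gpt2', 'device': 'cuda:0'}
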
+class TokenizedLM(LM):
+    @abc.abstractmethod
+    def tok_encode(self, string: str): pass
+
+    @abc.abstractmethod
+    def tok_decode(self, tokens: Iterable[int]): pass
+
+    @abc.abstractmethod
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False): pass
+
+    # subclass must implement properties vocab_size, eot_token_id, max_gen_toks.
+    # TODO: enforce this somehow
+
+    def loglikelihood(self, requests):
+        new_reqs = []
+        for context, continuation in requests:
+            if context == "":
+                # end of text as context
+                context_enc = [self.eot_token_id]
+            else:
+                context_enc = self.tok_encode(context)
+
+            continuation_enc = self.tok_encode(continuation)
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
+    def loglikelihood_rolling(self, requests):
+        # TODO: Implement caching once we've confirmed the perplexity implementation
+        # TODO: automatic batch size detection for vectorization
+
+        loglikelihoods = []
+        for string, in tqdm(requests):
+            rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
+                token_list=self.tok_encode(string),
+                prefix_token=self.eot_token_id,
+                max_seq_len=self.max_length,
+                context_len=1,
+            )))
+
+            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
+            string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
+
+            # discard is_greedy
+            string_nll = [x[0] for x in string_nll]
+
+            string_nll = sum(string_nll)
+            loglikelihoods.append(string_nll)
+
+        return loglikelihoods
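
Since `loglikelihood` and `loglikelihood_rolling` now live on `TokenizedLM`, a backend only has to supply tokenization, the attributes listed in the comment above, and a `_loglikelihood_tokens` implementation. A minimal illustrative sketch, assuming the `lm_eval.base` module from this commit is importable (the toy tokenizer and uniform scoring are hypothetical, not part of the commit):

    import math
    from lm_eval.base import TokenizedLM

    class ToyTokenizedLM(TokenizedLM):
        # attributes the base class expects (see the comment above)
        vocab_size = 1000
        eot_token_id = 0
        max_length = 16
        max_gen_toks = 4

        def tok_encode(self, string):
            # toy "tokenizer": hash whitespace-separated words into the vocab
            return [hash(w) % self.vocab_size for w in string.split()]

        def tok_decode(self, tokens):
            return " ".join(str(t) for t in tokens)

        def _loglikelihood_tokens(self, requests, disable_tqdm=False):
            # pretend every continuation token has probability 1/vocab_size
            return [(-len(cont) * math.log(self.vocab_size), False)
                    for _, _, cont in requests]

        def greedy_until(self, requests):
            raise NotImplementedError

    print(ToyTokenizedLM().loglikelihood([("The cat sat", " on the mat")]))
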

class Task(abc.ABC):
    """A task represents an entire benchmark including its dataset, problems,
    answers, and evaluation methods. See BoolQ for a simple example implementation
...
...
@@ -3,6 +3,7 @@ from . import gpt3
from . import dummy

MODEL_REGISTRY = {
+    "hf": gpt2.HFLM,
    "gpt2": gpt2.GPT2LM,
    "gpt3": gpt3.GPT3LM,
    "dummy": dummy.DummyLM,
...
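
With the new `"hf"` entry registered, callers resolve a class from the registry and build it through `create_from_arg_string`. An illustrative lookup (argument values are examples only; the harness's own CLI does the equivalent wiring):

    from lm_eval import models

    lm_cls = models.MODEL_REGISTRY["hf"]   # same class the old "gpt2" key now aliases
    lm = lm_cls.create_from_arg_string("pretrained=gpt2,device=cpu",
                                       additional_config={"batch_size": 1})
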
...
@@ -2,218 +2,131 @@ import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import LM, TokenizedLM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
+from abc import ABC, abstractmethod
+from typing import Iterable


-class GPT2LM(LM):
-    MAX_GEN_TOKS = 256
-
-    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
-        super().__init__()
-
-        assert isinstance(device, str)
-        assert isinstance(pretrained, str)
-        assert isinstance(batch_size, int)
-
-        if device:
-            self.device = torch.device(device)
-        else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-        # TODO: update this to be less of a hack once subfolder is fixed in HF
-        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")).to(self.device)
-        self.gpt2.eval()
-
-        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
-        assert isinstance(self.tokenizer, (
-            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
-            transformers.T5Tokenizer, transformers.T5TokenizerFast,
-        )), "this tokenizer has not been checked for compatibility yet!"
-
-        self.VOCAB_SIZE = self.tokenizer.vocab_size
-        self.EOT_TOKEN_ID = self.tokenizer.eos_token_id
-        print(self.EOT_TOKEN_ID)
-
-        try:
-            self.max_length = self.gpt2.config.n_ctx
-        except AttributeError:
-            # gptneoconfig doesn't have n_ctx apparently
-            self.max_length = self.gpt2.config.max_position_embeddings
-
-        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
-            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373]
-
-        # multithreading and batching
-        gpus = torch.cuda.device_count()
-        batch_size_per_gpu = batch_size  # todo: adaptive batch size
-        # TODO: fix multi-gpu
-        self.batch_size = batch_size_per_gpu  # * gpus
-
-        # TODO: fix multi-gpu
-        # if gpus > 1:
-        #     self.gpt2 = nn.DataParallel(self.gpt2)
-
-    @classmethod
-    def create_from_arg_string(cls, arg_string, additional_config={}):
-        args = utils.simple_parse_args_string(arg_string)
-        args2 = {k: v for k, v in additional_config.items() if v is not None}
-        return cls(**args, **args2)
-
-    def loglikelihood(self, requests):
-        new_reqs = []
-        for context, continuation in requests:
-            if context == "":
-                # end of text as context
-                context_enc = [self.EOT_TOKEN_ID]
-            else:
-                context_enc = self.tokenizer.encode(context, add_special_tokens=False)
-
-            continuation_enc = self.tokenizer.encode(continuation, add_special_tokens=False)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
-
-    def loglikelihood_rolling(self, requests):
-        # TODO: Implement caching once we've confirmed the perplexity implementation
-        # TODO: automatic batch size detection for vectorization
-
-        loglikelihoods = []
-        with torch.no_grad():
-            for string, in tqdm(requests):
-                rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows(
-                    token_list=self.tokenizer.encode(string, add_special_tokens=False),
-                    prefix_token=self.EOT_TOKEN_ID,
-                    max_seq_len=self.max_length,
-                    context_len=1,
-                )))
-
-                rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-                # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for that
-                string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True)
-
-                # discard is_greedy
-                string_nll = [x[0] for x in string_nll]
-
-                string_nll = sum(string_nll)
-                loglikelihoods.append(string_nll)
-        return loglikelihoods
+class TorchLM(TokenizedLM):
+    @abstractmethod
+    def _model_generate(self, context, max_length, eos_token_id):
+        pass
+
+    @abstractmethod
+    def _model_call(self, inps):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        pass
+
+    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
+    # TODO: enforce this somehow
    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
-        with torch.no_grad():

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch padded context length.
            #   this is useful to simplify the batching logic and more importantly to make automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = x[1] + x[2]
            return (-len(toks), tuple(toks))

        # TODO: automatic (variable) batch size detection for vectorization
        reord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
            inps = []
            contlens = []
            inplens = []

            padding_length = None

            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works:
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # gpt2    \               \
-                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.VOCAB_SIZE] slice
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.vocab_size] slice
                # cont_toks      4 5 6 7 8 9

                # when too long to fit in context, truncate from the left
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length+1):][:-1]
                , dtype=torch.long).to(self.device)
                inplen, = inp.shape

                cont = continuation_enc

                # since in _collate we make sure length is descending, the longest is always the first one.
                padding_length = padding_length if padding_length is not None else inplen

                # pad to length
                inp = torch.cat([
                    inp,  # [seq]
                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
                ], dim=0)

                inps.append(inp.unsqueeze(0))
                contlens.append(cont)
                inplens.append(inplen)

            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]

            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
                contlen = len(cont_toks)

                logits = logits[inplen-contlen:inplen].unsqueeze(0)  # [1, seq, vocab]

                greedy_tokens = logits.argmax(dim=-1)

                # cont_toks :: [1, seq]
                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)

                max_equal = (greedy_tokens == cont_toks).all()

                #last_token_slice = logits[:, -1, :].squeeze(0).tolist()

                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]

                answer = (float(logits.sum()), bool(max_equal))

                # partial caching
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return reord.get_original(res)
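
As a stand-alone illustration of the CTX/CONT slicing described in the comment above, here is a toy version of the scoring step, with random logits standing in for the model's forward pass (token values and shapes are made up for exposition):

    import torch
    import torch.nn.functional as F

    vocab_size = 11
    context_enc, continuation_enc = [0, 1, 2, 3], [4, 5, 6]

    inp = torch.tensor([(context_enc + continuation_enc)[:-1]])     # [1, seq]; last token dropped
    logits = torch.randn(1, inp.shape[1], vocab_size)               # stand-in for self._model_call(inp)
    logprobs = F.log_softmax(logits, dim=-1)

    contlen = len(continuation_enc)
    cont_logprobs = logprobs[:, -contlen:, :]                       # positions that predict the continuation
    cont_toks = torch.tensor([continuation_enc])                    # [1, contlen]

    ll = float(torch.gather(cont_logprobs, 2, cont_toks.unsqueeze(-1)).squeeze(-1).sum())
    is_greedy = bool((cont_logprobs.argmax(dim=-1) == cont_toks).all())
    print(ll, is_greedy)
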
-    def _model_call(self, inps):
-        """
-        inps: a torch tensor of shape [batch, sequence]
-        the size of sequence may vary from call to call
-
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        return self.gpt2(inps)[0][:, :, :50257]
-
    def greedy_until(self, requests):
        # TODO: implement fully general `until` that handles untils that are
        #       multiple tokens or that span multiple tokens correctly
+        # TODO: extract to TokenizedLM?

        res = []

        def _collate(x):
-            toks = self.tokenizer.encode(x[0], add_special_tokens=False)
+            toks = self.tok_encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)
...
@@ -221,18 +134,13 @@ class GPT2LM(LM):
        for context, until in tqdm(reord.get_reordered()):
            if isinstance(until, str): until = [until]

-            context_enc = torch.tensor([self.tokenizer.encode(context, add_special_tokens=False)[self.MAX_GEN_TOKS - self.max_length:]]).to(self.device)
-            primary_until, = self.tokenizer.encode(until[0], add_special_tokens=False)
-
-            cont = self.gpt2.generate(
-                context_enc,
-                max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
-                eos_token_id=primary_until,
-                do_sample=False
-            )
+            primary_until, = self.tok_encode(until[0])
+
+            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
+
+            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)

-            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
+            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])

            for term in until:
                s = s.split(term)[0]
...
@@ -243,3 +151,83 @@ class GPT2LM(LM):

            res.append(s)

        return reord.get_original(res)
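
The `until` handling at the end of `greedy_until` simply truncates the decoded text at the first stop string it finds; a tiny worked example of that post-processing (strings are arbitrary):

    generated = "42\n\nQuestion: what is"
    until = ["\n\n", "Question:"]
    for term in until:
        generated = generated.split(term)[0]
    print(repr(generated))  # '42'
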
+
+class HFLM(TorchLM):
+    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
+        super().__init__()
+
+        assert isinstance(device, str)
+        assert isinstance(pretrained, str)
+        assert isinstance(batch_size, int)
+
+        if device:
+            self.device = torch.device(device)
+        else:
+            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+        # TODO: update this to be less of a hack once subfolder is fixed in HF
+        self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")).to(self.device)
+        self.gpt2.eval()
+
+        # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
+        assert isinstance(self.tokenizer, (
+            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
+            transformers.T5Tokenizer, transformers.T5TokenizerFast,
+        )), "this tokenizer has not been checked for compatibility yet!"
+
+        self.vocab_size = self.tokenizer.vocab_size
+        self.eot_token_id = self.tokenizer.eos_token_id  # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        self.max_gen_toks = 256
+
+        try:
+            self.max_length = self.gpt2.config.n_ctx
+        except AttributeError:
+            # gptneoconfig doesn't have n_ctx apparently
+            self.max_length = self.gpt2.config.max_position_embeddings
+
+        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
+            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], self.tokenizer.encode('hello\n\nhello')
+
+        # multithreading and batching
+        gpus = torch.cuda.device_count()
+        batch_size_per_gpu = batch_size  # todo: adaptive batch size
+        # TODO: fix multi-gpu
+        self.batch_size = batch_size_per_gpu  # * gpus
+
+        # TODO: fix multi-gpu
+        # if gpus > 1:
+        #     self.gpt2 = nn.DataParallel(self.gpt2)
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inps):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        with torch.no_grad():
+            return self.gpt2(inps)[0][:, :, :50257]
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        return self.gpt2.generate(
+            context,
+            max_length=max_length,
+            eos_token_id=eos_token_id,
+            do_sample=False
+        )
+
+
+# for backwards compatibility
+GPT2LM = HFLM
\ No newline at end of file
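
A minimal usage sketch of the new HFLM wrapper (this downloads the gpt2 checkpoint from the HuggingFace hub on first use; the prompts are arbitrary examples):

    from lm_eval.models.gpt2 import HFLM

    lm = HFLM(pretrained="gpt2", device="cpu", batch_size=1)
    print(lm.loglikelihood([("The cat sat on the", " mat")]))   # [(log-likelihood, is_greedy)]
    print(lm.greedy_until([("Q: 2+2=\nA:", ["\n"])]))           # ["..."], cut at the first "\n"
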
import os
import numpy as np
import transformers
-from lm_eval.base import LM
+from lm_eval.base import LM, TokenizedLM
from lm_eval import utils
from tqdm import tqdm
import time
...
@@ -35,11 +35,8 @@ def oa_completion(**kwargs):
            backoff_time *= 1.5


-class GPT3LM(LM):
-    MAX_LENGTH = 2048
+class GPT3LM(TokenizedLM):
    REQ_CHUNK_SIZE = 20
-    MAX_GEN_TOKS = 256

    def __init__(self, engine, truncate=False):
        """
...
@@ -50,10 +47,15 @@ class GPT3LM(LM):
            Truncate input if too long (if False and input is too long, throw error)
        """
        super().__init__()

        import openai
        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+
+        self.vocab_size = self.tokenizer.vocab_size
+        self.eot_token_id = self.tokenizer.eos_token_id
+        self.max_gen_toks = 256
+        self.max_length = 2048

        # to make the annoying "Using pad_token, but it is not set yet." error go away
        self.tokenizer.pad_token = "<|endoftext|>"
...
@@ -63,27 +65,12 @@ class GPT3LM(LM):
        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

-    @classmethod
-    def create_from_arg_string(cls, arg_string, additional_config={}):
-        args = utils.simple_parse_args_string(arg_string)
-        args2 = {k: v for k, v in additional_config.items() if v is not None}
-        return cls(**args, **args2)
-
-    def loglikelihood(self, requests):
-        new_reqs = []
-        for context, continuation in requests:
-            if context == "":
-                # end of text as context
-                context_enc = [50256]
-            else:
-                context_enc = self.tokenizer.encode(context)
-
-            continuation_enc = self.tokenizer.encode(continuation)
-
-            new_reqs.append(((context, continuation), context_enc, continuation_enc))
-
-        return self._loglikelihood_tokens(new_reqs)
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)

    def loglikelihood_rolling(self, requests):
        # TODO: switch implementation to use _loglikelihood_tokens rather than having it do its own thing
...
@@ -94,7 +81,7 @@ class GPT3LM(LM):
            rolling_token_windows = utils.get_rolling_token_windows(
                token_list=encoded,
                prefix_token=self.end_of_text_token_id,
-                max_seq_len=self.MAX_LENGTH,
+                max_seq_len=self.max_length,
                context_len=1,
            )
            string_loglikelihoods = []
...
@@ -109,8 +96,28 @@ class GPT3LM(LM):
        return loglikelihoods

-    def _loglikelihood_tokens(self, requests):
-        import openai
+    def get_token_logprobs(self, input_tokens, pred_tokens):
+        pred_start = len(input_tokens) - len(pred_tokens) + 1
+        # We're going to stitch together the input_tokens and pred_tokens
+        # In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
+        assert input_tokens[pred_start:] == pred_tokens[:-1]
+        token_ids = input_tokens + [pred_tokens[-1]]
+        response = oa_completion(
+            engine=self.engine,
+            prompt=token_ids,
+            max_tokens=0,
+            temperature=0.0,
+            logprobs=0,
+            echo=True,
+        )
+        logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
+        positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
+        return {
+            "logprobs": logprobs,
+            "positions": positions,
+        }
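
A small worked example of the stitching precondition in `get_token_logprobs` above (token values are arbitrary): the prediction window must overlap the tail of the input by all but its final token, and that final token is appended so the echoed completion yields a log-prob for it too.

    input_tokens = [50256, 11, 22, 33]                       # prefix token + context window
    pred_tokens = [22, 33, 44]                               # tokens we want scored
    pred_start = len(input_tokens) - len(pred_tokens) + 1    # -> 2
    assert input_tokens[pred_start:] == pred_tokens[:-1]     # [22, 33] == [22, 33]
    token_ids = input_tokens + [pred_tokens[-1]]             # [50256, 11, 22, 33, 44] sent to the API
    # token_logprobs[pred_start:] then gives one log-prob per pred token (22, 33, 44 above)
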
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
...
@@ -122,12 +129,12 @@ class GPT3LM(LM):
        reord = utils.Reorderer(requests, _collate)

-        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
+        for chunk in tqdm(list(utils.chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE)), disable=disable_tqdm):
            inps = []
            ctxlens = []
            for cache_key, context_enc, continuation_enc in chunk:
-                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
-                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
+                inp = (context_enc + continuation_enc)[-self.max_length:]
+                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.max_length)
                inps.append(inp)
                ctxlens.append(ctxlen)
...
@@ -151,34 +158,13 @@ class GPT3LM(LM):
        return reord.get_original(res)
-    def get_token_logprobs(self, input_tokens, pred_tokens):
-        pred_start = len(input_tokens) - len(pred_tokens) + 1
-        # We're going to stitch together the input_tokens and pred_tokens
-        # In the longest case, this gets us to length = max_seq_len+1 (which the API works with)
-        assert input_tokens[pred_start:] == pred_tokens[:-1]
-        token_ids = input_tokens + [pred_tokens[-1]]
-        response = oa_completion(
-            engine=self.engine,
-            prompt=token_ids,
-            max_tokens=0,
-            temperature=0.0,
-            logprobs=0,
-            echo=True,
-        )
-        logprobs = np.array(response["choices"][0]["logprobs"]["token_logprobs"][pred_start:])
-        positions = np.arange(pred_start-1, pred_start-1 + len(token_ids[pred_start:]))
-        return {
-            "logprobs": logprobs,
-            "positions": positions,
-        }
-
    def greedy_until(self, requests):
        if not requests: return []
        import openai

        res = []

        def _collate(x):
-            toks = self.tokenizer.encode(x[0])
+            toks = self.tok_encode(x[0])
            return (len(toks), x[0])

        reord = utils.Reorderer(requests, _collate)
...
@@ -199,14 +185,14 @@ class GPT3LM(LM):
        for chunk, until in tqdm(list(sameuntil_chunks(reord.get_reordered(), self.REQ_CHUNK_SIZE))):
            inps = []
            for context, _ in chunk:
-                context_enc = self.tokenizer.encode(context)
-                inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
+                context_enc = self.tok_encode(context)
+                inp = context_enc[-(self.max_length - self.max_gen_toks):]
                inps.append(inp)

            response = oa_completion(
                engine=self.engine,
                prompt=inps,
-                max_tokens=self.MAX_GEN_TOKS,
+                max_tokens=self.max_gen_toks,
                temperature=0.,
                logprobs=10,
                stop=until
...
...
@@ -85,7 +85,7 @@ def test_gpt3_perplexity():
    assert perplexity == pytest.approx(tgt, rel=1e-3)

    # Hack: modify gpt3 to have shorter context length to induce rolling windows
-    gpt3.MAX_LENGTH = 5
+    gpt3.max_length = 5
    perplexity = gpt3.loglikelihood_rolling([(test_string,)])[0]
    tgt = -101.93490880000002
    assert perplexity == pytest.approx(tgt, rel=1e-3)