Commit 95bc8317 authored by Leo Gao

hacky goose

parent cef6aa8d
@@ -7,6 +7,7 @@ MODEL_REGISTRY = {
     "gpt2": gpt2.GPT2LM,
     "gpt3": gpt3.GPT3LM,
     "dummy": dummy.DummyLM,
+    "gooseai": gpt3.GooseAILM,
 }
@@ -21,12 +21,13 @@ def get_result(response, ctxlen):
         whether argmax matches given continuation exactly
     """
     is_greedy = True
-    logprobs = response["logprobs"]["token_logprobs"]
+    logprobs = response["logprobs"]["token_logprobs"][:-1]
     continuation_logprobs = sum(logprobs[ctxlen:])
+    print(logprobs[ctxlen:])
 
-    for i in range(ctxlen, len(response["logprobs"]["tokens"])):
-        token = response["logprobs"]["tokens"][i]
-        top_tokens = response["logprobs"]["top_logprobs"][i]
+    for i in range(ctxlen, len(response["logprobs"]["tokens"][:-1])):
+        token = response["logprobs"]["tokens"][:-1][i]
+        top_tokens = response["logprobs"]["top_logprobs"][:-1][i]
         top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
         if top_token != token:
             is_greedy = False
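The new [:-1] slices drop the last position of the response: because the scoring request later in this diff asks for max_tokens=1 instead of 0, the echoed prompt is followed by one freshly sampled token that must not be counted. A small sketch of the effect on a mocked, OpenAI-style logprobs payload (field names mirror those used above; the values are made up):

# Mocked echo-mode response: three echoed prompt tokens plus one extra sampled token.
response = {
    "logprobs": {
        "tokens": ["Hello", " world", "!", " How"],
        "token_logprobs": [None, -1.2, -0.4, -2.0],   # first echoed token has no logprob
        "top_logprobs": [None, {" world": -1.2}, {"!": -0.4}, {" How": -2.0}],
    }
}
ctxlen = 2  # "Hello" and " world" are context; "!" is the continuation being scored

logprobs = response["logprobs"]["token_logprobs"][:-1]  # drop the sampled " How"
print(sum(logprobs[ctxlen:]))                           # -0.4, the continuation log-probability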
@@ -35,6 +36,9 @@ def get_result(response, ctxlen):
 
     return continuation_logprobs, is_greedy
 
 
+class _goose:
+    choices: list
+
 def oa_completion(**kwargs):
     """ Query OpenAI API for completion.
@@ -42,9 +46,23 @@ def oa_completion(**kwargs):
     """
     import openai
     backoff_time = 3
+    # print(kwargs)
+    if len(kwargs["prompt"]) > 1 and isinstance(kwargs["prompt"], list):
+        import dask
+        res = []
+        for pmpt in kwargs["prompt"]:
+            k = kwargs.copy()
+            k["prompt"] = [pmpt]
+            res.append(dask.delayed(oa_completion)(**k))
+        r = dask.compute(*res)
+        ob = _goose()
+        ob.choices = [x.choices[0] for x in r]
+        return ob  # short-circuit: hand back the combined per-prompt choices
     while True:
         try:
-            return openai.Completion.create(**kwargs)
+            ret = openai.Completion.create(**kwargs)
+            # print(ret.choices[0])
+            return ret
         except openai.error.OpenAIError:
             import traceback
             traceback.print_exc()
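When a chunk contains more than one prompt, the new branch fans it out into one single-prompt request per entry, runs them in parallel with dask, and repackages the per-prompt choices into _goose, a duck-typed stand-in that only needs a .choices attribute because that is all the calling code reads from the response. A minimal, self-contained sketch of the same dask.delayed fan-out pattern, with a placeholder worker instead of the real API call:

import dask

def fetch_one(prompt):
    # Placeholder for a single-prompt completion request.
    return {"prompt": prompt, "completion": prompt.upper()}

prompts = ["alpha", "beta", "gamma"]
tasks = [dask.delayed(fetch_one)(p) for p in prompts]  # build lazy tasks; nothing runs yet
results = dask.compute(*tasks)                         # execute the tasks (threaded scheduler by default)
print([r["completion"] for r in results])              # ['ALPHA', 'BETA', 'GAMMA']

Since the underlying work is an I/O-bound HTTP call, the default threaded scheduler is enough to overlap the requests.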
@@ -55,7 +73,7 @@ def oa_completion(**kwargs):
 class GPT3LM(BaseLM):
     REQ_CHUNK_SIZE = 20
 
-    def __init__(self, engine, truncate=False):
+    def __init__(self, engine, truncate=False, api_key=None, pass_strings=False):
         """
 
         :param engine: str
@@ -68,6 +86,7 @@ class GPT3LM(BaseLM):
         import openai
 
         self.engine = engine
         self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
+        self.pass_strings = pass_strings
 
         self.vocab_size = self.tokenizer.vocab_size
@@ -78,7 +97,7 @@ class GPT3LM(BaseLM):
         self.end_of_text_token_id = self.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])[0]
 
         # Read from environment variable OPENAI_API_SECRET_KEY
-        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]
+        openai.api_key = api_key or os.environ["OPENAI_API_SECRET_KEY"]
 
     @property
     def eot_token_id(self):
@@ -130,6 +149,9 @@ class GPT3LM(BaseLM):
                 # TODO: the logic is much simpler if we just look at the length of continuation tokens
                 ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - (self.max_length+1))
 
+                print(inp)
+                if self.pass_strings:
+                    inp = self.tok_decode(inp)
                 inps.append(inp)
                 ctxlens.append(ctxlen)
@@ -137,7 +159,7 @@
                 engine=self.engine,
                 prompt=inps,
                 echo=True,
-                max_tokens=0, temperature=0.,
+                max_tokens=1,
                 logprobs=10,
             )
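max_tokens goes from 0 to 1 here, presumably because the GooseAI endpoint does not accept a pure scoring request with zero generated tokens the way the OpenAI API does; the single extra sampled token is exactly what the [:-1] trimming in get_result throws away. A rough sketch of the resulting echo-mode request against the pre-1.0 openai client (key and prompt are placeholders):

import openai

openai.api_base = "https://api.goose.ai/v1"   # as set in GooseAILM.__init__ later in this diff
openai.api_key = "sk-..."                     # placeholder key

resp = openai.Completion.create(
    engine="gpt-neo-20b",
    prompt=["The quick brown fox"],
    echo=True,        # return logprobs for the prompt tokens themselves
    max_tokens=1,     # one extra token is generated; get_result slices it off
    logprobs=10,
)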
@@ -182,14 +204,14 @@ class GPT3LM(BaseLM):
             for context, _ in chunk:
                 context_enc = self.tok_encode(context)
                 inp = context_enc[-(self.max_length - self.max_gen_toks):]
-                inps.append(inp)
+                inps.append(self.tok_decode(inp))
 
             response = oa_completion(
                 engine=self.engine,
                 prompt=inps,
                 max_tokens=self.max_gen_toks,
                 temperature=0.,
-                logprobs=10,
+                # logprobs=10,
                 stop=until,
             )
@@ -213,3 +235,31 @@ class GPT3LM(BaseLM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+
+class GooseAILM(GPT3LM):
+    def __init__(self, engine, truncate=False, api_key=None, force_pile_tokenizer=False):
+        super().__init__(engine, truncate=truncate, api_key=api_key or os.environ["GOOSEAI_API_SECRET_KEY"], pass_strings=True)
+        self.REQ_CHUNK_SIZE = 1
+        import openai
+        openai.api_base = "https://api.goose.ai/v1"
+
+        from best_download import download_file
+
+        if engine == "gpt-neo-20b" or force_pile_tokenizer:
+            download_file("http://eaidata.bmk.sh/data/pile_tokenizer.json", expected_checksum="d27f071586925d23ef1c4acdee28fb8bf5d99c4a9d638b4e3b08812e3eae6ee7", local_file="pile_tokenizer.json")
+            self.tokenizer = transformers.PreTrainedTokenizerFast(tokenizer_file="pile_tokenizer.json")
+
+    @property
+    def max_length(self):
+        # Note: this is temporary, will be raised to 2048 in the future
+        return 1023
+
+    @property
+    def eot_token_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_gen_toks(self):
+        return 64
\ No newline at end of file
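Finally, a hedged usage sketch tying the pieces together: constructing GooseAILM through the registry entry added in the first hunk. The module path lm_eval.models is an assumption (the diff does not show file names); the engine name and the environment variable come from the code above:

import os
os.environ["GOOSEAI_API_SECRET_KEY"] = "sk-..."    # placeholder; read by GooseAILM.__init__

from lm_eval.models import MODEL_REGISTRY          # assumed location of the registry

lm_cls = MODEL_REGISTRY["gooseai"]                 # -> gpt3.GooseAILM
lm = lm_cls(engine="gpt-neo-20b")                  # fetches pile_tokenizer.json for this engine

print(lm.max_length, lm.max_gen_toks)              # 1023 64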