Commit b1f7284e authored by Leo Gao's avatar Leo Gao
Browse files

Add retry with backoff for GPT3

parent c55e8237
...@@ -3,6 +3,7 @@ import transformers ...@@ -3,6 +3,7 @@ import transformers
from lm_eval.base import LM from lm_eval.base import LM
from lm_eval import utils from lm_eval import utils
from tqdm import tqdm from tqdm import tqdm
import time
def get_result(response, ctxlen): def get_result(response, ctxlen):
...@@ -21,6 +22,18 @@ def get_result(response, ctxlen): ...@@ -21,6 +22,18 @@ def get_result(response, ctxlen):
return continuation_logprobs, is_greedy return continuation_logprobs, is_greedy
def oa_completion(**kwargs):
    """Query the OpenAI ``Completion`` endpoint, retrying on API errors.

    Retries indefinitely with exponential backoff (starting at 3 s,
    factor 1.5, capped) so that long-running evaluation jobs survive
    rate limits and transient API outages instead of crashing.

    :param kwargs: forwarded verbatim to ``openai.Completion.create``.
    :return: the raw API response object.
    """
    # Deferred import so this module can be imported without the
    # ``openai`` package installed (only needed when actually querying).
    import openai
    import traceback

    backoff_time = 3
    while True:
        try:
            return openai.Completion.create(**kwargs)
        except openai.error.OpenAIError:
            # Surface the failure instead of silently swallowing it —
            # otherwise e.g. a bad API key retries forever with no output.
            traceback.print_exc()
            time.sleep(backoff_time)
            # Exponential backoff, capped so waits stay reasonable.
            backoff_time = min(backoff_time * 1.5, 60)
class GPT3LM(LM): class GPT3LM(LM):
MAX_LENGTH = 2048 MAX_LENGTH = 2048
...@@ -67,7 +80,7 @@ class GPT3LM(LM): ...@@ -67,7 +80,7 @@ class GPT3LM(LM):
inps.append(inp) inps.append(inp)
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
response = openai.Completion.create( response = oa_completion(
engine=self.engine, engine=self.engine,
prompt=inps, prompt=inps,
echo=True, echo=True,
...@@ -89,7 +102,7 @@ class GPT3LM(LM): ...@@ -89,7 +102,7 @@ class GPT3LM(LM):
inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):] inp = context_enc[-(self.MAX_LENGTH - self.MAX_GEN_TOKS):]
ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS)) ctxlen = len(context_enc) - max(0, len(context_enc) - (self.MAX_LENGTH - self.MAX_GEN_TOKS))
response = openai.Completion.create( response = oa_completion(
engine=self.engine, engine=self.engine,
prompt=[inp], prompt=[inp],
max_tokens=self.MAX_GEN_TOKS, max_tokens=self.MAX_GEN_TOKS,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment