Unverified Commit 046ea6e2 authored by Zach Schillaci, committed by GitHub

Generic decorator for handling rate limit errors (#1109)



* Add retry error handler

* fixup! Add retry error handler

* Move to utils.py

* Run isort on utils.py

* Catch multiple exceptions

* Update LMs with exception handler

* Fixes to anthropic retry handler

* fix callback kwarg

* Update textsynth.py

* fix python 3.8 incompatibility

* fix indenterror I introduced

* placate linter?

* Update on_exception_callback kwarg name

* fixup! Merge branch 'main' into add-retry-error-handler

* fixup! fixup! Merge branch 'main' into add-retry-error-handler

* Merge conflicts are fun

* Run pre-commit

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 09493fd2
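
For orientation, a minimal sketch of how the new `retry_on_specific_exceptions` decorator is meant to be used, mirroring the docstring added to `utils.py` in this commit; `MyRateLimitError`, `_on_retry`, and the wrapped call are placeholders, not part of the change:

```
from lm_eval.utils import retry_on_specific_exceptions


class MyRateLimitError(Exception):
    """Placeholder for a provider's rate-limit exception."""


def _on_retry(e: Exception, sleep_time: float) -> None:
    # Called before each sleep; receives the exception and the upcoming backoff.
    print(f"{type(e).__name__}: retrying in {sleep_time:.1f}s")


# max_retries=3 bounds the loop; the model backends in this commit pass None (retry forever).
@retry_on_specific_exceptions(
    on_exceptions=[MyRateLimitError],
    max_retries=3,
    on_exception_callback=_on_retry,
)
def completion():
    # Wrap the actual provider API call here.
    ...
```
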
-import time
 from typing import Any, List, Tuple

 from tqdm import tqdm

@@ -6,6 +5,7 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions

 eval_logger = utils.eval_logger
@@ -48,26 +48,30 @@ def anthropic_completion(
 please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
         )

-    backoff_time: float = 3
-    while True:
-        try:
-            response = client.completions.create(
-                prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
-                model=model,
-                # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
-                # (e.g. gsm8k's ":") may truncate a lot of the input.
-                stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
-                max_tokens_to_sample=max_tokens_to_sample,
-                temperature=temperature,
-                **kwargs,
-            )
-            return response.completion
-        except anthropic.RateLimitError as e:
-            eval_logger.warning(
-                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
-            )
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        eval_logger.warning(
+            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
+        )
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[anthropic.RateLimitError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        response = client.completions.create(
+            prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
+            model=model,
+            # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
+            # (e.g. gsm8k's ":") may truncate a lot of the input.
+            stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
+            max_tokens_to_sample=max_tokens_to_sample,
+            temperature=temperature,
+            **kwargs,
+        )
+        return response.completion
+
+    return completion()
@register_model("anthropic")
@@ -144,6 +148,14 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
         raise NotImplementedError("No support for logits.")

     def generate_until(self, requests) -> List[str]:
+        try:
+            import anthropic
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
+please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
+            )
+
         if not requests:
             return []
......
 import copy
 import os
-import time
 from collections import defaultdict
 from importlib.util import find_spec
 from typing import List, Optional, Tuple
@@ -10,6 +9,7 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions
def get_result(response, ctxlen: int) -> Tuple[float, bool]:
@@ -53,16 +53,20 @@ def oa_completion(**kwargs):
     else:
         import openai

-    backoff_time = 3
-    while True:
-        try:
-            return openai.completions.create(**kwargs)
-        except openai.OpenAIError:
-            import traceback
-
-            traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        import traceback
+
+        traceback.print_exc()
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[openai.OpenAIError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return openai.completions.create(**kwargs)
+
+    return completion()
@register_model("openai-completions")
@@ -337,20 +341,20 @@ def oa_chat_completion(client, **kwargs):
     else:
         import openai

-    async def _get_completions(**kwargs):
-        chat_completions = await client.chat.completions.create(**kwargs)
-        return chat_completions
-
-    backoff_time = 3
-    while True:
-        try:
-            return client.chat.completions.create(**kwargs)
-        except openai.OpenAIError:
-            import traceback
-
-            traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        import traceback
+
+        traceback.print_exc()
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[openai.OpenAIError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return client.chat.completions.create(**kwargs)
+
+    return completion()
@register_model("openai-chat-completions", "local-chat-completions")
......
@@ -13,13 +13,13 @@ Homepage: https://textsynth.com/index.html
 """
 import logging
 import os
-import time

 import requests as _requests
 from tqdm import tqdm

 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions

 logger = logging.getLogger(__name__)
@@ -29,21 +29,26 @@ def textsynth_completion(**kwargs):
     """Query TextSynth API for completion.
     Retry with back-off until they respond.
     """
-    backoff_time = 3
-    while True:
-        try:
-            return _requests.post(**kwargs)
-        except _requests.exceptions.RequestException:
-            import traceback
-
-            traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        import traceback
+
+        traceback.print_exc()
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[_requests.exceptions.RequestException],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return _requests.post(**kwargs)
+
+    return completion()


 @register_model("textsynth")
 class TextSynthLM(LM):
-    def __init__(self, engine, truncate: bool = False) -> None:
+    def __init__(self, engine, truncate: bool = False, **kwargs) -> None:
         """
         :param engine: str
             TextSynth API engine (e.g. `gptj_6B`)
......
@@ -10,8 +10,10 @@ import pathlib
 import re
 import subprocess
 import sys
+import time
+from functools import wraps
 from itertools import islice
-from typing import Any, Callable, Iterator, List, Literal, Union
+from typing import Any, Callable, Iterator, List, Literal, Optional, Type, Union

 import torch
 import transformers
@@ -714,3 +716,43 @@ def divide(iterable, n) -> List[Iterator]:
         ret.append(iter(seq[start:stop]))

     return ret
+
+
+def retry_on_specific_exceptions(
+    on_exceptions: List[Type[Exception]],
+    max_retries: Optional[int] = None,
+    backoff_time: float = 3.0,
+    backoff_multiplier: float = 1.5,
+    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
+):
+    """Retry on an LLM Provider's rate limit error with exponential backoff
+    For example, to use for OpenAI, do the following:
+    ```
+    from openai import RateLimitError
+
+    # Recommend specifying max_retries to avoid infinite loops!
+    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
+    def completion(...):
+        # Wrap OpenAI completion function here
+        ...
+    ```
+    """
+
+    def decorator(func: Callable):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sleep_time = backoff_time
+            attempt = 0
+            while max_retries is None or attempt < max_retries:
+                try:
+                    return func(*args, **kwargs)
+                except tuple(on_exceptions) as e:
+                    if on_exception_callback is not None:
+                        on_exception_callback(e, sleep_time)
+                    time.sleep(sleep_time)
+                    sleep_time *= backoff_multiplier
+                    attempt += 1
+
+        return wrapper
+
+    return decorator
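
As a quick sanity check of the decorator's behavior (not part of the commit), a self-contained example in which the wrapped function fails intermittently before succeeding; `TransientError` and `flaky_call` are illustrative names:

```
import random

from lm_eval.utils import retry_on_specific_exceptions


class TransientError(Exception):
    """Illustrative stand-in for a rate-limit error."""


@retry_on_specific_exceptions(
    on_exceptions=[TransientError],
    max_retries=5,
    backoff_time=0.1,
    backoff_multiplier=2.0,
    on_exception_callback=lambda e, t: print(f"retrying in {t:.1f}s"),
)
def flaky_call() -> str:
    # Fails ~half the time to exercise the retry path.
    if random.random() < 0.5:
        raise TransientError("simulated rate limit")
    return "ok"


print(flaky_call())
```

Note that, as written, the wrapper falls out of the loop and returns None once max_retries is exhausted rather than re-raising the last exception, so callers that set a bound should handle a None result.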