Unverified commit 046ea6e2, authored by Zach Schillaci, committed by GitHub

Generic decorator for handling rate limit errors (#1109)



* Add retry error handler

* fixup! Add retry error handler

* Move to utils.py

* Run isort on utils.py

* Catch multiple exceptions

* Update LMs with exception handler

* Fixes to anthropic retry handler

* fix callback kwarg

* Update textsynth.py

* fix python 3.8 incompatibility

* fix indent error I introduced

* placate linter?

* Update on_exception_callback kwarg name

* fixup! Merge branch 'main' into add-retry-error-handler

* fixup! fixup! Merge branch 'main' into add-retry-error-handler

* Merge conflicts are fun

* Run pre-commit

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 09493fd2
diff --git a/lm_eval/models/anthropic_llms.py b/lm_eval/models/anthropic_llms.py
@@ -1,4 +1,3 @@
-import time
 from typing import Any, List, Tuple

 from tqdm import tqdm
@@ -6,6 +5,7 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions

 eval_logger = utils.eval_logger
@@ -48,9 +48,17 @@ def anthropic_completion(
 please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
     )

-    backoff_time: float = 3
-    while True:
-        try:
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        eval_logger.warning(
+            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
+        )
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[anthropic.RateLimitError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
         response = client.completions.create(
             prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
             model=model,
@@ -62,12 +70,8 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
             **kwargs,
         )
         return response.completion
-    except anthropic.RateLimitError as e:
-        eval_logger.warning(
-            f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
-        )
-        time.sleep(backoff_time)
-        backoff_time *= 1.5
+
+    return completion()


 @register_model("anthropic")
@@ -144,6 +148,14 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
         raise NotImplementedError("No support for logits.")

     def generate_until(self, requests) -> List[str]:
+        try:
+            import anthropic
+        except ModuleNotFoundError:
+            raise Exception(
+                "attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
+please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
+            )
+
         if not requests:
             return []
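The hunks above, and the matching ones in the OpenAI and TextSynth files below, all apply the same pattern: the hand-rolled `while True` backoff loop is replaced by an inner `completion()` function decorated with `retry_on_specific_exceptions`. A minimal runnable sketch of that pattern, using a hypothetical `FakeRateLimitError` in place of `anthropic.RateLimitError` so it runs without any provider SDK:

```python
from lm_eval.utils import retry_on_specific_exceptions


class FakeRateLimitError(Exception):
    """Hypothetical stand-in for a provider's rate-limit error."""


def _exception_callback(e: Exception, sleep_time: float) -> None:
    print(f"{type(e).__name__} occurred; retrying in {sleep_time} seconds")


attempts = {"count": 0}


@retry_on_specific_exceptions(
    on_exceptions=[FakeRateLimitError],
    max_retries=5,  # bounded here, unlike the providers' max_retries=None
    backoff_time=0.1,  # short backoff so the demo finishes quickly
    on_exception_callback=_exception_callback,
)
def completion() -> str:
    attempts["count"] += 1
    if attempts["count"] < 3:  # fail twice, then succeed
        raise FakeRateLimitError("rate limited")
    return "ok"


print(completion())  # sleeps 0.1s, then 0.15s, before returning "ok"
```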
diff --git a/lm_eval/models/openai_completions.py b/lm_eval/models/openai_completions.py
@@ -1,6 +1,5 @@
 import copy
 import os
-import time
 from collections import defaultdict
 from importlib.util import find_spec
 from typing import List, Optional, Tuple
@@ -10,6 +9,7 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions


 def get_result(response, ctxlen: int) -> Tuple[float, bool]:
@@ -53,16 +53,20 @@ def oa_completion(**kwargs):
     else:
         import openai

-        backoff_time = 3
-        while True:
-            try:
-                return openai.completions.create(**kwargs)
-            except openai.OpenAIError:
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
         import traceback

         traceback.print_exc()
-                time.sleep(backoff_time)
-                backoff_time *= 1.5
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[openai.OpenAIError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return openai.completions.create(**kwargs)
+
+    return completion()


 @register_model("openai-completions")
@@ -337,20 +341,20 @@ def oa_chat_completion(client, **kwargs):
     else:
         import openai

-    async def _get_completions(**kwargs):
-        chat_completions = await client.chat.completions.create(**kwargs)
-        return chat_completions
-
-    backoff_time = 3
-    while True:
-        try:
-            return client.chat.completions.create(**kwargs)
-        except openai.OpenAIError:
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
         import traceback

         traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[openai.OpenAIError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return client.chat.completions.create(**kwargs)
+
+    return completion()


 @register_model("openai-chat-completions", "local-chat-completions")
diff --git a/lm_eval/models/textsynth.py b/lm_eval/models/textsynth.py
@@ -13,13 +13,13 @@ Homepage: https://textsynth.com/index.html
 """
 import logging
 import os
-import time

 import requests as _requests
 from tqdm import tqdm

 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions

 logger = logging.getLogger(__name__)
@@ -29,21 +29,26 @@ def textsynth_completion(**kwargs):
     """Query TextSynth API for completion.
     Retry with back-off until they respond.
     """
-    backoff_time = 3
-    while True:
-        try:
-            return _requests.post(**kwargs)
-        except _requests.exceptions.RequestException:
+
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
         import traceback

         traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[_requests.exceptions.RequestException],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return _requests.post(**kwargs)
+
+    return completion()


 @register_model("textsynth")
 class TextSynthLM(LM):
-    def __init__(self, engine, truncate: bool = False) -> None:
+    def __init__(self, engine, truncate: bool = False, **kwargs) -> None:
         """
         :param engine: str
             TextSynth API engine (e.g. `gptj_6B`)
diff --git a/lm_eval/utils.py b/lm_eval/utils.py
@@ -10,8 +10,10 @@ import pathlib
 import re
 import subprocess
 import sys
+import time
+from functools import wraps
 from itertools import islice
-from typing import Any, Callable, Iterator, List, Literal, Union
+from typing import Any, Callable, Iterator, List, Literal, Optional, Type, Union

 import torch
 import transformers
@@ -714,3 +716,43 @@ def divide(iterable, n) -> List[Iterator]:
         ret.append(iter(seq[start:stop]))

     return ret
+
+
+def retry_on_specific_exceptions(
+    on_exceptions: List[Type[Exception]],
+    max_retries: Optional[int] = None,
+    backoff_time: float = 3.0,
+    backoff_multiplier: float = 1.5,
+    on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
+):
+    """Retry on an LLM Provider's rate limit error with exponential backoff
+    For example, to use for OpenAI, do the following:
+    ```
+    from openai import RateLimitError
+
+    # Recommend specifying max_retries to avoid infinite loops!
+    @retry_on_specific_exceptions([RateLimitError], max_retries=3)
+    def completion(...):
+        # Wrap OpenAI completion function here
+        ...
+    ```
+    """
+
+    def decorator(func: Callable):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            sleep_time = backoff_time
+            attempt = 0
+            while max_retries is None or attempt < max_retries:
+                try:
+                    return func(*args, **kwargs)
+                except tuple(on_exceptions) as e:
+                    if on_exception_callback is not None:
+                        on_exception_callback(e, sleep_time)
+                    time.sleep(sleep_time)
+                    sleep_time *= backoff_multiplier
+                    attempt += 1
+
+        return wrapper
+
+    return decorator
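One behavior of the decorator worth calling out: when `max_retries` is finite and every attempt fails, the `while` loop simply exits and `wrapper` falls through, so the call returns `None` rather than re-raising the last exception. A small sketch of that exhaustion path (illustrative names, assuming the decorator above is importable from `lm_eval.utils`):

```python
from lm_eval.utils import retry_on_specific_exceptions


@retry_on_specific_exceptions(
    on_exceptions=[ValueError],
    max_retries=2,
    backoff_time=0.01,  # keep the demo fast
)
def always_fails() -> str:
    raise ValueError("still failing")


result = always_fails()  # two attempts: sleeps 0.01s, then 0.015s
assert result is None  # retries exhausted -> implicit None, no exception
```

Callers that pass a finite `max_retries` should therefore check for `None` (or wrap the call) rather than relying on an exception; the in-tree call sites sidestep this by passing `max_retries=None` and retrying forever.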