Commit eabadf46 authored by baberabb's avatar baberabb
Browse files

added type hints

parent e9b938f2
import os
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from tqdm import tqdm from tqdm import tqdm
import time import time
from lm_eval.logger import eval_logger from lm_eval.logger import eval_logger
from typing import List, Literal, Any from typing import List, Literal, Any, Tuple, Optional
def anthropic_completion( def anthropic_completion(
...@@ -15,10 +14,25 @@ def anthropic_completion( ...@@ -15,10 +14,25 @@ def anthropic_completion(
temperature: float, temperature: float,
stop: List[str], stop: List[str],
**kwargs: Any, **kwargs: Any,
): ) -> str:
"""Query Anthropic API for completion. """Wrapper function around the Anthropic completion API client with exponential back-off
in case of RateLimitError.
Retry with back-off until they respond
params:
client: anthropic.Anthropic
Anthropic API client
model: str
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
prompt: str
Prompt to feed to the model
max_tokens_to_sample: int
Maximum number of tokens to sample from the model
temperature: float
Sampling temperature
stop: List[str]
List of stop sequences
kwargs: Any
Additional model_args to pass to the API client
""" """
try: try:
...@@ -29,7 +43,7 @@ def anthropic_completion( ...@@ -29,7 +43,7 @@ def anthropic_completion(
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`", please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
) )
backoff_time = 3 backoff_time: float = 3
while True: while True:
try: try:
response = client.completions.create( response = client.completions.create(
...@@ -94,15 +108,15 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -94,15 +108,15 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
@property
def eot_token_id(self):
    """End-of-text token id for the model.

    Raises:
        NotImplementedError: always — Anthropic does not expose its
            tokenizer, so no end-of-text token id can be reported.
    """
    raise NotImplementedError("No idea about anthropic tokenization.")
@property
def max_length(self) -> int:
    """Maximum context length (in tokens) assumed for the model."""
    # Fixed conservative default; the API does not report a context size.
    return 2048
@property
def max_gen_toks(self) -> int:
    """Maximum number of tokens to sample per generation request."""
    return self.max_tokens_to_sample
@property @property
...@@ -124,14 +138,15 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -124,14 +138,15 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def _loglikelihood_tokens(self, requests, disable_tqdm=False): def _loglikelihood_tokens(self, requests, disable_tqdm=False):
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def greedy_until(self, requests): def greedy_until(self, requests) -> List[str]:
if not requests: if not requests:
return [] return []
requests = [req.args for req in requests] _requests: List[Tuple[str, dict]] = [req.args for req in requests]
res = [] res = []
for request in tqdm(requests): for request in tqdm(_requests):
try: try:
inp = request[0] inp = request[0]
request_args = request[1] request_args = request[1]
...@@ -145,16 +160,16 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -145,16 +160,16 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
prompt=inp, prompt=inp,
max_tokens_to_sample=max_gen_toks, max_tokens_to_sample=max_gen_toks,
temperature=temperature, # TODO: implement non-greedy sampling for Anthropic temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until, stop=until, # type: ignore
**self.kwargs, **self.kwargs,
) )
res.append(response) res.append(response)
self.cache_hook.add_partial("greedy_until", request, response) self.cache_hook.add_partial("greedy_until", request, response)
except anthropic.APIConnectionError as e: # noqa: F821 except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
eval_logger.critical(f"Server unreachable: {e.__cause__}") eval_logger.critical(f"Server unreachable: {e.__cause__}")
break break
except anthropic.APIStatusError as e: # noqa: F821 except anthropic.APIStatusError as e: # type: ignore # noqa: F821
eval_logger.critical(f"API error {e.status_code}: {e.message}") eval_logger.critical(f"API error {e.status_code}: {e.message}")
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment