import os
import time
import traceback

from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model


def anthropic_completion(
    client, model, prompt, max_tokens_to_sample, temperature, stop
):
    """Query the Anthropic API for a completion.

    Retries with exponential back-off until the API responds.
    """
    import anthropic  # deferred so the harness imports without the anthropic package

    backoff_time = 3
    while True:
        try:
            response = client.completion(
                prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                model=model,
                # NOTE: Claude really likes to do CoT, and overly aggressive stop
                # sequences (e.g. gsm8k's ":") may truncate a lot of the output.
                stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
                max_tokens_to_sample=max_tokens_to_sample,
                temperature=temperature,
            )
            return response["completion"]
        except RuntimeError:
            # TODO: It is unclear which exception the Anthropic client raises on
            # timeouts; update this handler once we find out.
            traceback.print_exc()
            time.sleep(backoff_time)
            backoff_time *= 1.5


@register_model("anthropic")
class AnthropicLM(LM):
    REQ_CHUNK_SIZE = 20

    def __init__(self, model):
        """
        :param model: str
            Anthropic model, e.g. claude-instant-v1
        """
        super().__init__()
        import anthropic  # deferred so the harness imports without the anthropic package

        self.model = model
        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])

    @property
    def eot_token_id(self):
        raise NotImplementedError("No idea about anthropic tokenization.")

    @property
    def max_length(self):
        return 2048

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    @property
    def device(self):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def tok_encode(self, string: str):
        raise NotImplementedError("No idea about anthropic tokenization.")

    def tok_decode(self, tokens):
        raise NotImplementedError("No idea about anthropic tokenization.")

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        raise NotImplementedError("No support for logits.")

    def greedy_until(self, requests):
        if not requests:
            return []

        requests = [req.args for req in requests]

        res = []
        for request in tqdm(requests):
            inp = request[0]
            request_args = request[1]
            until = request_args["until"]
            response = anthropic_completion(
                client=self.client,
                model=self.model,
                prompt=inp,
                max_tokens_to_sample=self.max_gen_toks,
                temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
                stop=until,
            )
            res.append(response)
            self.cache_hook.add_partial("greedy_until", request, response)

        return res

    def _model_call(self, inps):
        # Isn't used because we override _loglikelihood_tokens
        raise NotImplementedError()

    def _model_generate(self, context, max_length, eos_token_id):
        # Isn't used because we override greedy_until
        raise NotImplementedError()
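

# Usage sketch, kept as comments since the class is meant to be driven by the
# harness rather than run directly. Assumptions: lm_eval.simple_evaluate is the
# available entry point, ANTHROPIC_API_KEY is set in the environment, and the
# model name "claude-instant-v1" is illustrative, not prescriptive.
#
#     import lm_eval
#
#     results = lm_eval.simple_evaluate(
#         model="anthropic",  # resolved via @register_model("anthropic") above
#         model_args="model=claude-instant-v1",
#         tasks=["gsm8k"],
#     )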