Commit 8ffa0e67 authored by baberabb's avatar baberabb
Browse files

Updated the Anthropic model wrapper to the new Anthropic SDK API

parent 4e44f0aa
......@@ -3,21 +3,27 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
import time
import anthropic
from lm_eval.logger import eval_logger
from typing import List, Literal
def anthropic_completion(
    client: anthropic.Anthropic,
    model: str,
    prompt: str,
    max_tokens_to_sample: int,
    temperature: float,
    stop: List[str],
):
    """Query the Anthropic completions API, retrying with back-off on rate limits.

    :param client: an `anthropic.Anthropic` client instance
    :param model: Anthropic model name, e.g. 'claude-2'
    :param prompt: raw prompt text; HUMAN_PROMPT/AI_PROMPT wrapping is added here
    :param max_tokens_to_sample: maximum number of tokens to generate
    :param temperature: sampling temperature (0.0 = greedy)
    :param stop: stop sequences forwarded to the API
    :return: the completion string on success; implicitly None when the server
        is unreachable or returns an error status (the loop breaks without
        returning — callers should be prepared for a None entry)
    """
    backoff_time = 3
    while True:
        try:
            response = client.completions.create(
                prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                model=model,
                # NOTE: Claude really likes to do CoT, and overly aggressive
                # stop sequences may truncate a lot of the output, so pass the
                # caller's stop sequences through unchanged.
                stop_sequences=stop,
                max_tokens_to_sample=max_tokens_to_sample,
                temperature=temperature,
            )
            return response.completion
        except anthropic.RateLimitError as e:
            # Transient: wait and retry with multiplicative back-off.
            eval_logger.warning(
                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
            )
            time.sleep(backoff_time)
            backoff_time *= 1.5
        except anthropic.APIConnectionError as e:
            # Permanent for this request: give up (function falls through to None).
            eval_logger.critical(f"Server unreachable: {e.__cause__}")
            break
        except anthropic.APIStatusError as e:
            # Non-retryable HTTP status from the API: give up.
            eval_logger.critical(f"API error {e.status_code}: {e.message}")
            break
@register_model("anthropic")
class AnthropicLM(LM):
    REQ_CHUNK_SIZE = 20  # TODO: not used

    def __init__(
        self,
        batch_size=None,
        model: str = "claude-2.0",
        max_tokens_to_sample: int = 256,
        temperature: float = 0.0,
    ):  # TODO: remove batch_size
        """Anthropic API wrapper.

        :param batch_size: unused; kept for interface compatibility with other LMs
        :param model: str
            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
        :param max_tokens_to_sample: int
            maximum number of tokens to generate per request
        :param temperature: float
            sampling temperature; 0.0 selects greedy decoding
        """
        super().__init__()
        self.model = model
        # Reads the API key from the ANTHROPIC_API_KEY environment variable.
        self.client = anthropic.Anthropic()
        self.temperature = temperature
        self.max_tokens_to_sample = max_tokens_to_sample
        # Client-side tokenizer used by tok_encode / tok_decode.
        self.tokenizer = self.client.get_tokenizer()
@property
def eot_token_id(self):
# Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
raise NotImplementedError("No idea about anthropic tokenization.")
@property
......@@ -63,23 +82,23 @@ class AnthropicLM(LM):
@property
def max_gen_toks(self):
return 256
return self.max_tokens_to_sample
@property
def batch_size(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
raise NotImplementedError("No support for logits.")
@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
raise NotImplementedError("No support for logits.")
def tok_encode(self, string: str):
raise NotImplementedError("No idea about anthropic tokenization.")
def tok_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string).ids
def tok_decode(self, tokens):
raise NotImplementedError("No idea about anthropic tokenization.")
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
raise NotImplementedError("No support for logits.")
......@@ -99,8 +118,8 @@ class AnthropicLM(LM):
client=self.client,
model=self.model,
prompt=inp,
max_tokens_to_sample=self.max_gen_toks,
temperature=0.0, # TODO: implement non-greedy sampling for Anthropic
max_tokens_to_sample=self.max_tokens_to_sample,
temperature=self.temperature, # TODO: implement non-greedy sampling for Anthropic
stop=until,
)
res.append(response)
......@@ -116,3 +135,9 @@ class AnthropicLM(LM):
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
def loglikelihood(self, requests):
raise NotImplementedError("No support for logits.")
def loglikelihood_rolling(self, requests):
raise NotImplementedError("No support for logits.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment