Commit 8ffa0e67 authored by baberabb

updated anthropic to new API

parent 4e44f0aa
@@ -3,21 +3,27 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 import time
+import anthropic
+from lm_eval.logger import eval_logger
+from typing import List, Literal


 def anthropic_completion(
-    client, model, prompt, max_tokens_to_sample, temperature, stop
+    client: anthropic.Anthropic,
+    model: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: float,
+    stop: List[str],
 ):
     """Query Anthropic API for completion.
     Retry with back-off until they respond.
     """
-    import anthropic
     backoff_time = 3
     while True:
         try:
-            response = client.completion(
+            response = client.completions.create(
                 prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                 model=model,
                 # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
@@ -26,35 +32,48 @@ def anthropic_completion(
                 max_tokens_to_sample=max_tokens_to_sample,
                 temperature=temperature,
             )
-            return response["completion"]
-        except RuntimeError:
-            # TODO: I don't actually know what error Anthropic raises when it times out
-            # So err update this error when we find out.
-            import traceback
-            traceback.print_exc()
+            return response.completion
+        except anthropic.RateLimitError as e:
+            eval_logger.warning(
+                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
+            )
             time.sleep(backoff_time)
             backoff_time *= 1.5
+        except anthropic.APIConnectionError as e:
+            eval_logger.critical(f"Server unreachable: {e.__cause__}")
+            break
+        except anthropic.APIStatusError as e:
+            eval_logger.critical(f"API error {e.status_code}: {e.message}")
+            break
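For context, a minimal sketch of calling the updated helper against the new SDK surface. The model name and stop sequence below are illustrative, and it assumes ANTHROPIC_API_KEY is exported, which anthropic.Anthropic() reads automatically. Note the behavior as committed: rate limits are retried with an uncapped back-off of 3 s, 4.5 s, 6.75 s, ..., while connection and status errors break out of the loop, so the helper implicitly returns None in those cases.

import anthropic

client = anthropic.Anthropic()  # credentials come from ANTHROPIC_API_KEY
text = anthropic_completion(
    client=client,
    model="claude-2.0",                    # illustrative model name
    prompt="What is the capital of France?",
    max_tokens_to_sample=64,
    temperature=0.0,
    stop=["\n\nHuman:"],                   # illustrative stop sequence
)
print(text)  # completion string, or None if the client gave up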
@register_model("anthropic") @register_model("anthropic")
class AnthropicLM(LM): class AnthropicLM(LM):
REQ_CHUNK_SIZE = 20 REQ_CHUNK_SIZE = 20 # TODO: not used
def __init__(self, model): def __init__(
""" self,
batch_size=None,
model: str = "claude-2.0",
max_tokens_to_sample: int = 256,
temperature: float = 0.0,
): # TODO: remove batch_size
"""Anthropic API wrapper.
:param model: str :param model: str
Anthropic model e.g. claude-instant-v1 Anthropic model e.g. 'claude-instant-v1', 'claude-2'
""" """
super().__init__() super().__init__()
import anthropic
self.model = model self.model = model
self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"]) self.client = anthropic.Anthropic()
self.temperature = temperature
self.max_tokens_to_sample = max_tokens_to_sample
self.tokenizer = self.client.get_tokenizer()
@property @property
def eot_token_id(self): def eot_token_id(self):
# Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
raise NotImplementedError("No idea about anthropic tokenization.") raise NotImplementedError("No idea about anthropic tokenization.")
@property @property
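The constructor no longer takes the key explicitly (the old code passed os.environ["ANTHROPIC_API_KEY"] to anthropic.Client); the new anthropic.Anthropic() client picks it up from the environment on its own, and sampling parameters now live on the instance. A minimal construction sketch, with illustrative values and assuming the key is already exported:

lm = AnthropicLM(model="claude-2.0", max_tokens_to_sample=256, temperature=0.0)
assert isinstance(lm.client, anthropic.Anthropic)  # authenticated from the environment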
@@ -63,23 +82,23 @@ class AnthropicLM(LM):

     @property
     def max_gen_toks(self):
-        return 256
+        return self.max_tokens_to_sample

     @property
     def batch_size(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

     @property
     def device(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

-    def tok_encode(self, string: str):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string).ids

-    def tok_decode(self, tokens):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         raise NotImplementedError("No support for logits.")
@@ -99,8 +118,8 @@ class AnthropicLM(LM):
                 client=self.client,
                 model=self.model,
                 prompt=inp,
-                max_tokens_to_sample=self.max_gen_toks,
-                temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
+                max_tokens_to_sample=self.max_tokens_to_sample,
+                temperature=self.temperature,  # TODO: implement non-greedy sampling for Anthropic
                 stop=until,
             )
             res.append(response)
@@ -116,3 +135,9 @@ class AnthropicLM(LM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")