Unverified Commit 465c695b authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #710 from baberabb/big-refactor_claude

[Refactor] Updated anthropic to new API
parents 6efc8d5e 471297ba
@@ -55,7 +55,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
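The new extras make the Anthropic SDK and sentencepiece available to the test jobs. A quick import check one could run in that environment (illustrative only, not part of this commit; both packages expose a standard __version__ attribute):

import anthropic
import sentencepiece

# If either import fails, the '[testing,anthropic,sentencepiece]' extras
# above were not installed correctly.
print(anthropic.__version__, sentencepiece.__version__)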
@@ -3,21 +3,28 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 import time
+import anthropic
+from lm_eval.logger import eval_logger
+from typing import List, Literal, Any
 
 
 def anthropic_completion(
-    client, model, prompt, max_tokens_to_sample, temperature, stop
+    client: anthropic.Anthropic,
+    model: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: float,
+    stop: List[str],
+    **kwargs: Any,
 ):
     """Query Anthropic API for completion.
 
     Retry with back-off until they respond
     """
-    import anthropic
-
     backoff_time = 3
     while True:
         try:
-            response = client.completion(
+            response = client.completions.create(
                 prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                 model=model,
                 # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
@@ -25,36 +32,53 @@ def anthropic_completion(
                 stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
                 max_tokens_to_sample=max_tokens_to_sample,
                 temperature=temperature,
+                **kwargs,
             )
 
-            return response["completion"]
-        except RuntimeError:
-            # TODO: I don't actually know what error Anthropic raises when it times out
-            #       So err update this error when we find out.
-            import traceback
-
-            traceback.print_exc()
+            return response.completion
+        except anthropic.RateLimitError as e:
+            eval_logger.warning(
+                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
+            )
+
             time.sleep(backoff_time)
             backoff_time *= 1.5
 
 
 @register_model("anthropic")
 class AnthropicLM(LM):
-    REQ_CHUNK_SIZE = 20
+    REQ_CHUNK_SIZE = 20  # TODO: not used
 
-    def __init__(self, model):
-        """
+    def __init__(
+        self,
+        batch_size: int = 1,
+        model: str = "claude-2.0",
+        max_tokens_to_sample: int = 256,
+        temperature: float = 0,  # defaults to 1
+        **kwargs,  # top_p, top_k, etc.
+    ):
+        """Anthropic API wrapper.
 
         :param model: str
-            Anthropic model e.g. claude-instant-v1
+            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
+        :param max_tokens_to_sample: int
+            Maximum number of tokens to sample from the model
+        :param temperature: float
+            Sampling temperature
+        :param kwargs: Any
+            Additional model_args to pass to the API client
         """
         super().__init__()
-        import anthropic
 
         self.model = model
-        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+        # defaults to os.environ.get("ANTHROPIC_API_KEY")
+        self.client = anthropic.Anthropic()
+        self.temperature = temperature
+        self.max_tokens_to_sample = max_tokens_to_sample
+        self.tokenizer = self.client.get_tokenizer()
+        self.kwargs = kwargs
 
     @property
     def eot_token_id(self):
         # Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
         raise NotImplementedError("No idea about anthropic tokenization.")
 
     @property
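With the client now built as anthropic.Anthropic(), the retry helper above can be exercised on its own. A minimal sketch, assuming ANTHROPIC_API_KEY is set in the environment and using an illustrative model name and prompt:

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

# anthropic_completion wraps client.completions.create with retry/back-off
# on rate limits; HUMAN_PROMPT/AI_PROMPT framing is added inside the helper.
text = anthropic_completion(
    client=client,
    model="claude-2.0",  # illustrative
    prompt="What is the capital of France?",
    max_tokens_to_sample=64,
    temperature=0.0,
    stop=["\n\n"],
)
print(text)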
@@ -63,23 +87,23 @@ class AnthropicLM(LM):
 
     @property
     def max_gen_toks(self):
-        return 256
+        return self.max_tokens_to_sample
 
     @property
     def batch_size(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")
 
     @property
     def device(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")
 
-    def tok_encode(self, string: str):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string).ids
 
-    def tok_decode(self, tokens):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)
 
     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         raise NotImplementedError("No support for logits.")
@@ -92,20 +116,31 @@ class AnthropicLM(LM):
         res = []
         for request in tqdm(requests):
-            inp = request[0]
-            request_args = request[1]
-            until = request_args["until"]
-            response = anthropic_completion(
-                client=self.client,
-                model=self.model,
-                prompt=inp,
-                max_tokens_to_sample=self.max_gen_toks,
-                temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
-                stop=until,
-            )
-            res.append(response)
-            self.cache_hook.add_partial("greedy_until", request, response)
+            try:
+                inp = request[0]
+                request_args = request[1]
+                # generation_kwargs
+                until = request_args.get("until")
+                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
+                temperature = request_args.get("temperature", self.temperature)
+                response = anthropic_completion(
+                    client=self.client,
+                    model=self.model,
+                    prompt=inp,
+                    max_tokens_to_sample=max_gen_toks,
+                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
+                    stop=until,
+                    **self.kwargs,
+                )
+                res.append(response)
+                self.cache_hook.add_partial("greedy_until", request, response)
+            except anthropic.APIConnectionError as e:
+                eval_logger.critical(f"Server unreachable: {e.__cause__}")
+                break
+            except anthropic.APIStatusError as e:
+                eval_logger.critical(f"API error {e.status_code}: {e.message}")
+                break
 
         return res
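As the loop above indexes them, each request is a (prompt, generation_kwargs) pair, with 'until', 'max_gen_toks', and 'temperature' read from the kwargs dict. A hypothetical direct call (request contents illustrative):

lm = AnthropicLM(model="claude-2.0")

requests = [
    ("Question: What is 2 + 2?\nAnswer:", {"until": ["\n"], "max_gen_toks": 32}),
]
completions = lm.greedy_until(requests)  # one completion string per request
print(completions[0])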
@@ -116,3 +151,9 @@ class AnthropicLM(LM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")
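End to end, @register_model("anthropic") makes the wrapper reachable by name from the harness. A hypothetical run through the evaluator API (simple_evaluate and the task name come from the harness at large, not this diff; the key is a placeholder):

import os
from lm_eval import evaluator

os.environ["ANTHROPIC_API_KEY"] = "..."  # placeholder; supply a real key

results = evaluator.simple_evaluate(
    model="anthropic",  # resolves to AnthropicLM via the registry
    model_args="model=claude-2.0,temperature=0",
    tasks=["lambada_openai"],
)
print(results["results"])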