Unverified Commit 465c695b authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #710 from baberabb/big-refactor_claude

[Refactor] Updated anthropic to new API
parents 6efc8d5e 471297ba
@@ -55,7 +55,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
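If the new extras fail to resolve, the test job dies at import time. A minimal sanity check of the two added extras, run separately from the workflow (illustrative only, not part of this diff):

# Illustrative check that the optional extras installed by the new
# pip line are importable; package names come from the extras list above.
import anthropic
import sentencepiece

print("anthropic", anthropic.__version__)
print("sentencepiece", sentencepiece.__version__)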
@@ -3,21 +3,28 @@
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 import time
+import anthropic
+from lm_eval.logger import eval_logger
+from typing import List, Literal, Any


 def anthropic_completion(
-    client, model, prompt, max_tokens_to_sample, temperature, stop
+    client: anthropic.Anthropic,
+    model: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: float,
+    stop: List[str],
+    **kwargs: Any,
 ):
     """Query Anthropic API for completion.

     Retry with back-off until they respond
     """
-    import anthropic
-
     backoff_time = 3
     while True:
         try:
-            response = client.completion(
+            response = client.completions.create(
                 prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                 model=model,
                 # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
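For readers unfamiliar with the legacy completions API: anthropic.HUMAN_PROMPT and anthropic.AI_PROMPT are the SDK constants "\n\nHuman:" and "\n\nAssistant:", and the prompt must be framed exactly as above. A standalone sketch of the formatting, with the constants inlined for illustration:

# Inlined copies of the SDK constants, for illustration only;
# real code should use anthropic.HUMAN_PROMPT / anthropic.AI_PROMPT.
HUMAN_PROMPT = "\n\nHuman:"
AI_PROMPT = "\n\nAssistant:"

def format_claude_prompt(user_text: str) -> str:
    # The legacy completions endpoint requires this exact framing,
    # ending with the assistant turn left open for the model to fill.
    return f"{HUMAN_PROMPT} {user_text}{AI_PROMPT}"

print(repr(format_claude_prompt("What is 2 + 2?")))
# '\n\nHuman: What is 2 + 2?\n\nAssistant:'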
@@ -25,36 +32,53 @@ def anthropic_completion(
                 stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
                 max_tokens_to_sample=max_tokens_to_sample,
                 temperature=temperature,
+                **kwargs,
             )
-            return response["completion"]
-        except RuntimeError:
-            # TODO: I don't actually know what error Anthropic raises when it times out
-            # So err update this error when we find out.
-            import traceback
-
-            traceback.print_exc()
+            return response.completion
+        except anthropic.RateLimitError as e:
+            eval_logger.warning(
+                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
+            )
             time.sleep(backoff_time)
             backoff_time *= 1.5


 @register_model("anthropic")
 class AnthropicLM(LM):
-    REQ_CHUNK_SIZE = 20
+    REQ_CHUNK_SIZE = 20  # TODO: not used

-    def __init__(self, model):
-        """
+    def __init__(
+        self,
+        batch_size: int = 1,
+        model: str = "claude-2.0",
+        max_tokens_to_sample: int = 256,
+        temperature: float = 0,  # defaults to 1
+        **kwargs,  # top_p, top_k, etc.
+    ):
+        """Anthropic API wrapper.
+
         :param model: str
-            Anthropic model e.g. claude-instant-v1
+            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
+        :param max_tokens_to_sample: int
+            Maximum number of tokens to sample from the model
+        :param temperature: float
+            Sampling temperature
+        :param kwargs: Any
+            Additional model_args to pass to the API client
         """
         super().__init__()
-        import anthropic

         self.model = model
-        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+        # defaults to os.environ.get("ANTHROPIC_API_KEY")
+        self.client = anthropic.Anthropic()
+        self.temperature = temperature
+        self.max_tokens_to_sample = max_tokens_to_sample
+        self.tokenizer = self.client.get_tokenizer()
+        self.kwargs = kwargs

     @property
     def eot_token_id(self):
+        # Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
         raise NotImplementedError("No idea about anthropic tokenization.")

     @property
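Put together, constructing the refactored wrapper looks roughly like this. A hypothetical usage sketch: the model name and extra kwargs below are examples, not values fixed by the diff, and it assumes ANTHROPIC_API_KEY is set in the environment since anthropic.Anthropic() reads it by default:

import os

assert "ANTHROPIC_API_KEY" in os.environ  # the client reads this itself
lm = AnthropicLM(
    model="claude-2.0",
    max_tokens_to_sample=256,
    temperature=0,
    top_p=0.95,  # forwarded to completions.create via **kwargs
)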
@@ -63,23 +87,23 @@ class AnthropicLM(LM):
     @property
     def max_gen_toks(self):
-        return 256
+        return self.max_tokens_to_sample

     @property
     def batch_size(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

     @property
     def device(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

-    def tok_encode(self, string: str):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string).ids

-    def tok_decode(self, tokens):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         raise NotImplementedError("No support for logits.")
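get_tokenizer() in the SDK of this era returned a tokenizers.Tokenizer, so encode(...) yields an Encoding object whose .ids attribute is the token-id list used in tok_encode above. A sketch of the round-trip, assuming a valid client:

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the env
tok = client.get_tokenizer()    # a tokenizers.Tokenizer in this SDK version
ids = tok.encode("Hello, Claude!").ids  # Encoding -> list of ints
print(ids)
print(tok.decode(ids))          # should recover the original text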
@@ -92,20 +116,31 @@ class AnthropicLM(LM):
         res = []
         for request in tqdm(requests):
-            inp = request[0]
-            request_args = request[1]
-            until = request_args["until"]
-            response = anthropic_completion(
-                client=self.client,
-                model=self.model,
-                prompt=inp,
-                max_tokens_to_sample=self.max_gen_toks,
-                temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
-                stop=until,
-            )
-            res.append(response)
-
-            self.cache_hook.add_partial("greedy_until", request, response)
+            try:
+                inp = request[0]
+                request_args = request[1]
+                # generation_kwargs
+                until = request_args.get("until")
+                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
+                temperature = request_args.get("temperature", self.temperature)
+                response = anthropic_completion(
+                    client=self.client,
+                    model=self.model,
+                    prompt=inp,
+                    max_tokens_to_sample=max_gen_toks,
+                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
+                    stop=until,
+                    **self.kwargs,
+                )
+                res.append(response)
+
+                self.cache_hook.add_partial("greedy_until", request, response)
+            except anthropic.APIConnectionError as e:
+                eval_logger.critical(f"Server unreachable: {e.__cause__}")
+                break
+            except anthropic.APIStatusError as e:
+                eval_logger.critical(f"API error {e.status_code}: {e.message}")
+                break
         return res
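Each request unpacked above is a (context, generation_kwargs) pair. A hypothetical end-to-end call, with made-up prompt and stop sequence:

# Hypothetical invocation; prompt and stop sequence are examples only.
lm = AnthropicLM(model="claude-2.0")
outputs = lm.greedy_until(
    [("Q: What is the capital of France?\nA:", {"until": ["\n"]})]
)
print(outputs[0])  # e.g. " Paris"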
@@ -116,3 +151,9 @@ class AnthropicLM(LM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")