Unverified Commit cc4eab6a authored by Jason Phang's avatar Jason Phang Committed by GitHub
Browse files

Add Anthropic support (#562)

* add anthropic support

* move requirement
parent 9c89d1fc
from . import gpt2
from . import gpt3
from . import anthropic_llms
from . import huggingface
from . import textsynth
from . import dummy
......@@ -11,6 +12,7 @@ MODEL_REGISTRY = {
"hf-seq2seq": huggingface.AutoSeq2SeqLM,
"gpt2": gpt2.GPT2LM,
"gpt3": gpt3.GPT3LM,
"anthropic": anthropic_llms.AnthropicLM,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
}
......
import os
from lm_eval.base import BaseLM
from tqdm import tqdm
import time
def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperature, stop):
"""Query Anthropic API for completion.
Retry with back-off until they respond
"""
import anthropic
backoff_time = 3
while True:
try:
response = client.completion(
prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
model=model,
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
# (e.g. gsm8k's ":") may truncate a lot of the input.
stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
max_tokens_to_sample=max_tokens_to_sample,
temperature=temperature,
)
print(response)
return response["completion"]
except RuntimeError:
# TODO: I don't actually know what error Anthropic raises when it times out
# So err update this error when we find out.
import traceback
traceback.print_exc()
time.sleep(backoff_time)
backoff_time *= 1.5
class AnthropicLM(BaseLM):
REQ_CHUNK_SIZE = 20
def __init__(self, model):
"""
:param model: str
Anthropic model e.g. claude-instant-v1
"""
super().__init__()
import anthropic
self.model = model
self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY'])
@property
def eot_token_id(self):
raise NotImplementedError("No idea about anthropic tokenization.")
@property
def max_length(self):
return 2048
@property
def max_gen_toks(self):
return 256
@property
def batch_size(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def tok_encode(self, string: str):
raise NotImplementedError("No idea about anthropic tokenization.")
def tok_decode(self, tokens):
raise NotImplementedError("No idea about anthropic tokenization.")
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
raise NotImplementedError("No support for logits.")
def greedy_until(self, requests):
if not requests:
return []
res = []
for request in tqdm(requests):
inp = request[0]
request_args = request[1]
until = request_args["until"]
response = anthropic_completion(
client=self.client,
model=self.model,
prompt=inp,
max_tokens_to_sample=self.max_gen_toks,
temperature=0.0,
stop=until,
)
res.append(response)
return res
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
......@@ -45,5 +45,6 @@ setuptools.setup(
"multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1"],
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
"auto-gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
"anthropic": ["anthropic"],
},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment