Unverified Commit 5a49b2a3 authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Merge pull request #738 from baberabb/master_anthropic

[Main] updated to new anthropic API
parents fe803c29 d504944b
......@@ -4,7 +4,9 @@ from tqdm import tqdm
import time
def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperature, stop):
def anthropic_completion(
client, model, prompt, max_tokens_to_sample, temperature, stop
):
"""Query Anthropic API for completion.
Retry with back-off until they respond
......@@ -14,7 +16,7 @@ def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperatur
backoff_time = 3
while True:
try:
response = client.completion(
response = client.completions.create(
prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
model=model,
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
......@@ -24,7 +26,7 @@ def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperatur
temperature=temperature,
)
print(response)
return response["completion"]
return response.completion
except RuntimeError:
# TODO: I don't actually know what error Anthropic raises when it times out
# So err update this error when we find out.
......@@ -38,7 +40,7 @@ def anthropic_completion(client, model, prompt, max_tokens_to_sample, temperatur
class AnthropicLM(BaseLM):
REQ_CHUNK_SIZE = 20
def __init__(self, model):
def __init__(self, model="claude-2"):
"""
:param model: str
......@@ -46,8 +48,9 @@ class AnthropicLM(BaseLM):
"""
super().__init__()
import anthropic
self.model = model
self.client = anthropic.Client(os.environ['ANTHROPIC_API_KEY'])
self.client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
@property
def eot_token_id(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment