Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
8ffa0e67
Commit
8ffa0e67
authored
Jul 28, 2023
by
baberabb
Browse files
updated anthropic to new API
parent
4e44f0aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
26 deletions
+51
-26
lm_eval/models/anthropic_llms.py
lm_eval/models/anthropic_llms.py
+51
-26
No files found.
lm_eval/models/anthropic_llms.py
View file @
8ffa0e67
...
@@ -3,21 +3,27 @@ from lm_eval.api.model import LM
...
@@ -3,21 +3,27 @@ from lm_eval.api.model import LM
from
lm_eval.api.registry
import
register_model
from
lm_eval.api.registry
import
register_model
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
time
import
time
import
anthropic
from
lm_eval.logger
import
eval_logger
from
typing
import
List
,
Literal
def
anthropic_completion
(
def
anthropic_completion
(
client
,
model
,
prompt
,
max_tokens_to_sample
,
temperature
,
stop
client
:
anthropic
.
Anthropic
,
model
:
str
,
prompt
:
str
,
max_tokens_to_sample
:
int
,
temperature
:
float
,
stop
:
List
[
str
],
):
):
"""Query Anthropic API for completion.
"""Query Anthropic API for completion.
Retry with back-off until they respond
Retry with back-off until they respond
"""
"""
import
anthropic
backoff_time
=
3
backoff_time
=
3
while
True
:
while
True
:
try
:
try
:
response
=
client
.
completion
(
response
=
client
.
completion
s
.
create
(
prompt
=
f
"
{
anthropic
.
HUMAN_PROMPT
}
{
prompt
}{
anthropic
.
AI_PROMPT
}
"
,
prompt
=
f
"
{
anthropic
.
HUMAN_PROMPT
}
{
prompt
}{
anthropic
.
AI_PROMPT
}
"
,
model
=
model
,
model
=
model
,
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
...
@@ -26,35 +32,48 @@ def anthropic_completion(
...
@@ -26,35 +32,48 @@ def anthropic_completion(
max_tokens_to_sample
=
max_tokens_to_sample
,
max_tokens_to_sample
=
max_tokens_to_sample
,
temperature
=
temperature
,
temperature
=
temperature
,
)
)
return
response
[
"completion"
]
return
response
.
completion
except
RuntimeError
:
except
anthropic
.
RateLimitError
as
e
:
# TODO: I don't actually know what error Anthropic raises when it times out
eval_logger
.
warning
(
# So err update this error when we find out.
f
"RateLimitError occurred:
{
e
.
__cause__
}
\n
Retrying in
{
backoff_time
}
seconds"
import
traceback
)
traceback
.
print_exc
()
time
.
sleep
(
backoff_time
)
time
.
sleep
(
backoff_time
)
backoff_time
*=
1.5
backoff_time
*=
1.5
except
anthropic
.
APIConnectionError
as
e
:
eval_logger
.
critical
(
f
"Server unreachable:
{
e
.
__cause__
}
"
)
break
except
anthropic
.
APIStatusError
as
e
:
eval_logger
.
critical
(
f
"API error
{
e
.
status_code
}
:
{
e
.
message
}
"
)
break
@
register_model
(
"anthropic"
)
@
register_model
(
"anthropic"
)
class
AnthropicLM
(
LM
):
class
AnthropicLM
(
LM
):
REQ_CHUNK_SIZE
=
20
REQ_CHUNK_SIZE
=
20
# TODO: not used
def
__init__
(
self
,
model
):
def
__init__
(
"""
self
,
batch_size
=
None
,
model
:
str
=
"claude-2.0"
,
max_tokens_to_sample
:
int
=
256
,
temperature
:
float
=
0.0
,
):
# TODO: remove batch_size
"""Anthropic API wrapper.
:param model: str
:param model: str
Anthropic model e.g. claude-instant-v1
Anthropic model e.g.
'
claude-instant-v1
', 'claude-2'
"""
"""
super
().
__init__
()
super
().
__init__
()
import
anthropic
self
.
model
=
model
self
.
model
=
model
self
.
client
=
anthropic
.
Client
(
os
.
environ
[
"ANTHROPIC_API_KEY"
])
self
.
client
=
anthropic
.
Anthropic
()
self
.
temperature
=
temperature
self
.
max_tokens_to_sample
=
max_tokens_to_sample
self
.
tokenizer
=
self
.
client
.
get_tokenizer
()
@
property
@
property
def
eot_token_id
(
self
):
def
eot_token_id
(
self
):
# Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
raise
NotImplementedError
(
"No idea about anthropic tokenization."
)
raise
NotImplementedError
(
"No idea about anthropic tokenization."
)
@
property
@
property
...
@@ -63,23 +82,23 @@ class AnthropicLM(LM):
...
@@ -63,23 +82,23 @@ class AnthropicLM(LM):
@
property
@
property
def
max_gen_toks
(
self
):
def
max_gen_toks
(
self
):
return
256
return
self
.
max_tokens_to_sample
@
property
@
property
def
batch_size
(
self
):
def
batch_size
(
self
):
# Isn't used because we override _loglikelihood_tokens
# Isn't used because we override _loglikelihood_tokens
raise
NotImplementedError
()
raise
NotImplementedError
(
"No support for logits."
)
@
property
@
property
def
device
(
self
):
def
device
(
self
):
# Isn't used because we override _loglikelihood_tokens
# Isn't used because we override _loglikelihood_tokens
raise
NotImplementedError
()
raise
NotImplementedError
(
"No support for logits."
)
def
tok_encode
(
self
,
string
:
str
):
def
tok_encode
(
self
,
string
:
str
)
->
List
[
int
]
:
r
aise
NotImplementedError
(
"No idea about anthropic tokenization."
)
r
eturn
self
.
tokenizer
.
encode
(
string
).
ids
def
tok_decode
(
self
,
tokens
)
:
def
tok_decode
(
self
,
tokens
:
List
[
int
])
->
str
:
r
aise
NotImplementedError
(
"No idea about anthropic tokenization."
)
r
eturn
self
.
tokenizer
.
decode
(
tokens
)
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
):
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
):
raise
NotImplementedError
(
"No support for logits."
)
raise
NotImplementedError
(
"No support for logits."
)
...
@@ -99,8 +118,8 @@ class AnthropicLM(LM):
...
@@ -99,8 +118,8 @@ class AnthropicLM(LM):
client
=
self
.
client
,
client
=
self
.
client
,
model
=
self
.
model
,
model
=
self
.
model
,
prompt
=
inp
,
prompt
=
inp
,
max_tokens_to_sample
=
self
.
max_
g
en_to
ks
,
max_tokens_to_sample
=
self
.
max_
tok
en
s
_to
_sample
,
temperature
=
0.0
,
# TODO: implement non-greedy sampling for Anthropic
temperature
=
self
.
temperature
,
# TODO: implement non-greedy sampling for Anthropic
stop
=
until
,
stop
=
until
,
)
)
res
.
append
(
response
)
res
.
append
(
response
)
...
@@ -116,3 +135,9 @@ class AnthropicLM(LM):
...
@@ -116,3 +135,9 @@ class AnthropicLM(LM):
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
):
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
):
# Isn't used because we override greedy_until
# Isn't used because we override greedy_until
raise
NotImplementedError
()
raise
NotImplementedError
()
def
loglikelihood
(
self
,
requests
):
raise
NotImplementedError
(
"No support for logits."
)
def
loglikelihood_rolling
(
self
,
requests
):
raise
NotImplementedError
(
"No support for logits."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment