Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
465c695b
Unverified
Commit
465c695b
authored
Jul 31, 2023
by
Hailey Schoelkopf
Committed by
GitHub
Jul 31, 2023
Browse files
Merge pull request #710 from baberabb/big-refactor_claude
[Refactor] Updated anthropic to new API
parents
6efc8d5e
471297ba
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
40 deletions
+81
-40
.github/workflows/unit_tests.yml
.github/workflows/unit_tests.yml
+1
-1
lm_eval/models/anthropic_llms.py
lm_eval/models/anthropic_llms.py
+80
-39
No files found.
.github/workflows/unit_tests.yml
View file @
465c695b
...
...
@@ -55,7 +55,7 @@ jobs:
-
name
:
Install dependencies
run
:
|
python -m pip install --upgrade pip
pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e '.[testing
,anthropic,sentencepiece
]' --extra-index-url https://download.pytorch.org/whl/cpu
# Install optional git dependencies
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
...
...
lm_eval/models/anthropic_llms.py
View file @
465c695b
...
...
@@ -3,21 +3,28 @@ from lm_eval.api.model import LM
from
lm_eval.api.registry
import
register_model
from
tqdm
import
tqdm
import
time
import
anthropic
from
lm_eval.logger
import
eval_logger
from
typing
import
List
,
Literal
,
Any
def
anthropic_completion
(
client
,
model
,
prompt
,
max_tokens_to_sample
,
temperature
,
stop
client
:
anthropic
.
Anthropic
,
model
:
str
,
prompt
:
str
,
max_tokens_to_sample
:
int
,
temperature
:
float
,
stop
:
List
[
str
],
**
kwargs
:
Any
,
):
"""Query Anthropic API for completion.
Retry with back-off until they respond
"""
import
anthropic
backoff_time
=
3
while
True
:
try
:
response
=
client
.
completion
(
response
=
client
.
completion
s
.
create
(
prompt
=
f
"
{
anthropic
.
HUMAN_PROMPT
}
{
prompt
}{
anthropic
.
AI_PROMPT
}
"
,
model
=
model
,
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
...
...
@@ -25,36 +32,53 @@ def anthropic_completion(
stop_sequences
=
[
anthropic
.
HUMAN_PROMPT
]
+
stop
,
max_tokens_to_sample
=
max_tokens_to_sample
,
temperature
=
temperature
,
**
kwargs
,
)
return
response
.
completion
except
anthropic
.
RateLimitError
as
e
:
eval_logger
.
warning
(
f
"RateLimitError occurred:
{
e
.
__cause__
}
\n
Retrying in
{
backoff_time
}
seconds"
)
return
response
[
"completion"
]
except
RuntimeError
:
# TODO: I don't actually know what error Anthropic raises when it times out
# So update this exception type once we find out.
import
traceback
traceback
.
print_exc
()
time
.
sleep
(
backoff_time
)
backoff_time
*=
1.5
@
register_model
(
"anthropic"
)
class
AnthropicLM
(
LM
):
REQ_CHUNK_SIZE
=
20
def
__init__
(
self
,
model
):
"""
REQ_CHUNK_SIZE
=
20
# TODO: REQ_CHUNK_SIZE is currently not used
def
__init__
(
self
,
batch_size
:
int
=
1
,
model
:
str
=
"claude-2.0"
,
max_tokens_to_sample
:
int
=
256
,
temperature
:
float
=
0
,
# the Anthropic API itself defaults to 1
**
kwargs
,
# top_p, top_k, etc.
):
"""Anthropic API wrapper.
:param model: str
Anthropic model e.g. claude-instant-v1
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
:param max_tokens_to_sample: int
Maximum number of tokens to sample from the model
:param temperature: float
Sampling temperature
:param kwargs: Any
Additional model_args to pass to the API client
"""
super
().
__init__
()
import
anthropic
self
.
model
=
model
self
.
client
=
anthropic
.
Client
(
os
.
environ
[
"ANTHROPIC_API_KEY"
])
# defaults to os.environ.get("ANTHROPIC_API_KEY")
self
.
client
=
anthropic
.
Anthropic
()
self
.
temperature
=
temperature
self
.
max_tokens_to_sample
=
max_tokens_to_sample
self
.
tokenizer
=
self
.
client
.
get_tokenizer
()
self
.
kwargs
=
kwargs
@
property
def
eot_token_id
(
self
):
# Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
raise
NotImplementedError
(
"No idea about anthropic tokenization."
)
@
property
...
...
@@ -63,23 +87,23 @@ class AnthropicLM(LM):
@
property
def
max_gen_toks
(
self
):
return
256
return
self
.
max_tokens_to_sample
@
property
def
batch_size
(
self
):
# Isn't used because we override _loglikelihood_tokens
raise
NotImplementedError
()
raise
NotImplementedError
(
"No support for logits."
)
@
property
def
device
(
self
):
# Isn't used because we override _loglikelihood_tokens
raise
NotImplementedError
()
raise
NotImplementedError
(
"No support for logits."
)
def
tok_encode
(
self
,
string
:
str
):
r
aise
NotImplementedError
(
"No idea about anthropic tokenization."
)
def
tok_encode
(
self
,
string
:
str
)
->
List
[
int
]
:
r
eturn
self
.
tokenizer
.
encode
(
string
).
ids
def
tok_decode
(
self
,
tokens
)
:
r
aise
NotImplementedError
(
"No idea about anthropic tokenization."
)
def
tok_decode
(
self
,
tokens
:
List
[
int
])
->
str
:
r
eturn
self
.
tokenizer
.
decode
(
tokens
)
def
_loglikelihood_tokens
(
self
,
requests
,
disable_tqdm
=
False
):
raise
NotImplementedError
(
"No support for logits."
)
...
...
@@ -92,20 +116,31 @@ class AnthropicLM(LM):
res
=
[]
for
request
in
tqdm
(
requests
):
inp
=
request
[
0
]
request_args
=
request
[
1
]
until
=
request_args
[
"until"
]
response
=
anthropic_completion
(
client
=
self
.
client
,
model
=
self
.
model
,
prompt
=
inp
,
max_tokens_to_sample
=
self
.
max_gen_toks
,
temperature
=
0.0
,
# TODO: implement non-greedy sampling for Anthropic
stop
=
until
,
)
res
.
append
(
response
)
self
.
cache_hook
.
add_partial
(
"greedy_until"
,
request
,
response
)
try
:
inp
=
request
[
0
]
request_args
=
request
[
1
]
# generation_kwargs
until
=
request_args
.
get
(
"until"
)
max_gen_toks
=
request_args
.
get
(
"max_gen_toks"
,
self
.
max_length
)
temperature
=
request_args
.
get
(
"temperature"
,
self
.
temperature
)
response
=
anthropic_completion
(
client
=
self
.
client
,
model
=
self
.
model
,
prompt
=
inp
,
max_tokens_to_sample
=
max_gen_toks
,
temperature
=
temperature
,
# TODO: implement non-greedy sampling for Anthropic
stop
=
until
,
**
self
.
kwargs
,
)
res
.
append
(
response
)
self
.
cache_hook
.
add_partial
(
"greedy_until"
,
request
,
response
)
except
anthropic
.
APIConnectionError
as
e
:
eval_logger
.
critical
(
f
"Server unreachable:
{
e
.
__cause__
}
"
)
break
except
anthropic
.
APIStatusError
as
e
:
eval_logger
.
critical
(
f
"API error
{
e
.
status_code
}
:
{
e
.
message
}
"
)
break
return
res
...
...
@@ -116,3 +151,9 @@ class AnthropicLM(LM):
def
_model_generate
(
self
,
context
,
max_length
,
eos_token_id
):
# Isn't used because we override greedy_until
raise
NotImplementedError
()
def
loglikelihood
(
self
,
requests
):
raise
NotImplementedError
(
"No support for logits."
)
def
loglikelihood_rolling
(
self
,
requests
):
raise
NotImplementedError
(
"No support for logits."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment