Unverified Commit 0a0f879c authored by Lintang Sutawika, committed by GitHub

Merge pull request #724 from EleutherAI/Refactor]-fix-anthropic-import-bug

[Refactor] Fix Anthropic Import
parents ca36c41c e49972b0
@@ -51,7 +51,7 @@ pip install -e ".[gptq]"
## Support
The best way to get support is to open an issue on this repo or join the [EleutherAI discord server](discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for recieving support for our releases.
The best way to get support is to open an issue on this repo or join the [EleutherAI discord server](discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases.
## Basic Usage
......
@@ -114,7 +114,12 @@ def simple_evaluate(
task_dict = lm_eval.tasks.get_task_dict(tasks)
for task_name in task_dict.keys():
config = task_dict[task_name]._config
task_obj = task_dict[task_name]
if type(task_obj) == tuple:
group, task_obj = task_obj
config = task_obj._config
if num_fewshot is not None:
if config["num_fewshot"] > 0:
default_num_fewshot = config["num_fewshot"]
@@ -122,7 +127,7 @@ def simple_evaluate(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_dict[task_name]._config["num_fewshot"] = num_fewshot
task_obj._config["num_fewshot"] = num_fewshot
if check_integrity:
run_task_tests(task_list=tasks)
......
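The `simple_evaluate` hunk above handles the case where an entry of `get_task_dict` may now be a `(group, task)` tuple rather than a bare task object, unpacking it before reading or overriding `_config`. A minimal, self-contained sketch of that pattern; `FakeTask` and the example entries are hypothetical stand-ins, not lm-eval code:

```python
class FakeTask:
    """Hypothetical stand-in for an lm-eval task object carrying a `_config` dict."""
    def __init__(self, num_fewshot):
        self._config = {"num_fewshot": num_fewshot}

task_dict = {
    "arc_easy": FakeTask(25),                # plain entry: just the task object
    "mmlu_anatomy": ("mmlu", FakeTask(5)),   # grouped entry: (group_name, task object)
}

num_fewshot = 0  # value passed on the command line
for task_name, task_obj in task_dict.items():
    if isinstance(task_obj, tuple):          # grouped tasks arrive as (group, task)
        group, task_obj = task_obj
    config = task_obj._config
    if num_fewshot is not None and config["num_fewshot"] > 0:
        print(
            f"Overwriting default num_fewshot of {task_name} "
            f"from {config['num_fewshot']} to {num_fewshot}"
        )
        task_obj._config["num_fewshot"] = num_fewshot
```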
from . import huggingface
from . import openai_completions
from . import anthropic_llms
from . import textsynth
from . import dummy
from . import anthropic_llms
# TODO: implement __all__
@@ -3,13 +3,12 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm
import time
import anthropic
from lm_eval.logger import eval_logger
from typing import List, Literal, Any
def anthropic_completion(
client: anthropic.Anthropic,
client, #: anthropic.Anthropic,
model: str,
prompt: str,
max_tokens_to_sample: int,
@@ -21,6 +20,15 @@ def anthropic_completion(
Retry with back-off until they respond
"""
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
)
backoff_time = 3
while True:
try:
@@ -68,6 +76,14 @@ class AnthropicLM(LM):
"""
super().__init__()
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
)
self.model = model
# defaults to os.environ.get("ANTHROPIC_API_KEY")
self.client = anthropic.Anthropic()
@@ -135,10 +151,10 @@ class AnthropicLM(LM):
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
except anthropic.APIConnectionError as e:
except anthropic.APIConnectionError as e: # noqa: F821
eval_logger.critical(f"Server unreachable: {e.__cause__}")
break
except anthropic.APIStatusError as e:
except anthropic.APIStatusError as e: # noqa: F821
eval_logger.critical(f"API error {e.status_code}: {e.message}")
break
......
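The `anthropic_llms.py` changes are the heart of this PR: the module-level `import anthropic` moves into `anthropic_completion` and the `AnthropicLM` constructor, so importing `lm_eval.models` no longer fails when the package is absent, and the `# noqa: F821` markers silence flake8's undefined-name warning on the `anthropic.*` exception classes. A minimal sketch of the same deferred-import idea plus the retry-with-back-off loop the docstring mentions; the helper names `_require_anthropic` and `call_with_backoff` and the back-off factor are assumptions, not the library's API:

```python
import time

def _require_anthropic():
    """Import anthropic only when an Anthropic model is actually requested."""
    try:
        import anthropic
    except ModuleNotFoundError as exc:
        raise Exception(
            "attempted to use 'anthropic' LM type, but package `anthropic` is not "
            "installed. please install anthropic via `pip install lm-eval[anthropic]` "
            "or `pip install -e .[anthropic]`"
        ) from exc
    return anthropic

def call_with_backoff(request_fn, backoff_time=3.0):
    """Retry a flaky API call, waiting a little longer after each failure."""
    while True:
        try:
            return request_fn()
        except Exception as err:
            print(f"request failed ({err!r}); retrying in {backoff_time:.0f}s")
            time.sleep(backoff_time)
            backoff_time *= 1.5
```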
@@ -361,23 +361,29 @@ class HFLM(LM):
max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
else:
max_length = self.max_length
# if OOM, then halves batch_size and tries again
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
def forward_batch(batch_size):
if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
length = max(max_context_enc, max_cont_enc)
batched_conts = torch.ones((batch_size, length), device=self.device).long()
batched_conts = torch.ones(
(batch_size, length), device=self.device
).long()
test_batch = torch.ones((batch_size, length), device=self.device).long()
call_kwargs = {
"attn_mask": test_batch,
"labels": batched_conts,
}
"attn_mask": test_batch,
"labels": batched_conts,
}
else:
call_kwargs = {}
test_batch = torch.ones((batch_size, max_length), device=self.device).long()
test_batch = torch.ones(
(batch_size, max_length), device=self.device
).long()
for _ in range(5):
out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)
out = out # Identity process so that it passes pre-commit
return batch_size
batch_size = forward_batch()
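For context, `find_executable_batch_size` (from Hugging Face `accelerate`) re-invokes the decorated function with a halved `batch_size` whenever it raises a CUDA out-of-memory error, which is how `forward_batch()` above probes for the largest batch that fits. A toy illustration, with the OOM condition simulated rather than coming from a real GPU:

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=512)
def forward_batch(batch_size):
    # Pretend anything above 64 exhausts GPU memory; accelerate catches the
    # OOM-looking error and retries with the batch size halved.
    if batch_size > 64:
        raise RuntimeError("CUDA out of memory.")
    return batch_size

print(forward_batch())  # 512 -> 256 -> 128 -> 64, prints 64
```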
@@ -391,12 +397,10 @@ class HFLM(LM):
batch_size = min(gathered)
utils.clear_torch_cache()
return batch_size
utils.clear_torch_cache()
return batch_size
def tok_encode(self, string: str, left_truncate_len=None):
""" """
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -573,7 +577,9 @@ class HFLM(LM):
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size
rolling_token_windows,
disable_tqdm=True,
override_bs=adaptive_batch_size,
)
if (self.world_size > 1) and (pad_amnt > 0):
@@ -601,26 +607,31 @@ class HFLM(LM):
toks = x[1] + x[2]
return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate)
n_reordered_requests = len(re_ord.get_reordered())
# automatic (variable) batch size detection for vectorization
# pull longest context sample from request
def _batch_scheduler(pos):
sched = pos // int(n_reordered_requests / self.batch_schedule)
if sched in self.batch_sizes:
return self.batch_sizes[sched]
if (len(self.batch_sizes) > 1) and (self.batch_sizes[sched-1] == self.max_batch_size):
if (len(self.batch_sizes) > 1) and (
self.batch_sizes[sched - 1] == self.max_batch_size
):
# if previous batch size is already maximal, skip recomputation
self.batch_sizes[sched] = self.max_batch_size
return self.batch_sizes[sched]
print(
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
)
self.batch_sizes[sched] = self._detect_batch_size(re_ord.get_reordered(), pos)
self.batch_sizes[sched] = self._detect_batch_size(
re_ord.get_reordered(), pos
)
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
return self.batch_sizes[sched]
return self.batch_sizes[sched]
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
@@ -630,7 +641,9 @@ class HFLM(LM):
if override_bs is not None
else 0,
fn=_batch_scheduler
if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
if self.batch_size == "auto"
and n_reordered_requests > 0
and not override_bs
else None,
):
inps = []
......
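The `_batch_scheduler` hunk above mostly reflows line breaks, but the logic is worth spelling out: requests are sorted longest-first, split into `batch_schedule` buckets by position, the expensive batch-size probe runs once per bucket, and a bucket skips the probe once a previous bucket already reached `max_batch_size`. A standalone sketch of that caching scheme; the names below are illustrative, not the class attributes:

```python
batch_schedule = 4        # how many times to re-detect over the sorted requests
max_batch_size = 64       # hard ceiling on any detected batch size
batch_sizes = {}          # bucket index -> detected batch size

def batch_scheduler(pos, n_requests, detect):
    """Return the batch size to use for the request at position `pos`."""
    sched = pos // (n_requests // batch_schedule)
    if sched in batch_sizes:
        return batch_sizes[sched]
    if batch_sizes and batch_sizes.get(sched - 1) == max_batch_size:
        # earlier (longer) requests already fit at the ceiling, so shorter ones will too
        batch_sizes[sched] = max_batch_size
        return batch_sizes[sched]
    batch_sizes[sched] = min(detect(pos), max_batch_size)  # run the probe once per bucket
    return batch_sizes[sched]

# e.g. with a probe that always reports room for 64 sequences:
print(batch_scheduler(0, 1000, lambda pos: 64))    # probes, returns 64
print(batch_scheduler(300, 1000, lambda pos: 64))  # previous bucket hit the ceiling, probe skipped
```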
@@ -32,7 +32,7 @@ def parse_args():
default=None,
help="Number of examples in few-shot context",
)
parser.add_argument("--batch_size", type=str, default=1)
parser.add_argument("--batch_size", type=str, default=1)
parser.add_argument(
"--max_batch_size",
type=int,
......
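`--batch_size` stays a string so it can carry either a number or the auto-detection modes referenced in `huggingface.py` (`auto`, `auto:N`). A hypothetical sketch of how such a value could be split apart; the helper name and the default schedule of 1 are assumptions, not taken from `main.py`:

```python
def parse_batch_size(batch_size: str):
    """Turn '16', 'auto', or 'auto:4' into (size-or-'auto', re-detection count)."""
    if batch_size.startswith("auto"):
        schedule = int(batch_size.split(":")[1]) if ":" in batch_size else 1
        return "auto", schedule
    return int(batch_size), None

print(parse_batch_size("8"))       # (8, None)
print(parse_batch_size("auto"))    # ('auto', 1)
print(parse_batch_size("auto:4"))  # ('auto', 4)
```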