"...targets/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "695d873e59f3b35ab316284bb13af4433f1f9715"
Commit 0d1ef037 authored by lintangsutawika

solved merge conflict

parents aa44be3f ada4a31d
import random
import itertools
-import json
import collections
-import sys
import torch

@@ -17,8 +15,6 @@ import lm_eval.api.registry
from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
-   make_table,
-   create_iterator,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger,
@@ -91,7 +87,7 @@ def simple_evaluate(
    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
-           f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
+           "generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
        )
        if gen_kwargs == "":
            gen_kwargs = None
@@ -118,7 +114,9 @@ def simple_evaluate(
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
-           + "_rank" + str(lm.rank) + ".db",
+           + "_rank"
+           + str(lm.rank)
+           + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)
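For intuition, the per-rank suffix above simply produces a distinct cache path for every data-parallel rank, so ranks never write to the same SQLite file. A minimal sketch (the "lm_cache/mymodel" prefix is a hypothetical --use_cache value, not taken from this commit):

use_cache = "lm_cache/mymodel"  # hypothetical cache path prefix
for rank in range(4):           # e.g. four data-parallel ranks
    print(use_cache + "_rank" + str(rank) + ".db")
# -> lm_cache/mymodel_rank0.db ... lm_cache/mymodel_rank3.db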
@@ -234,9 +232,6 @@ def evaluate(
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
-   # store the ordering of tasks and groups
-   task_order = collections.defaultdict(int)
-   task_group_alias = collections.defaultdict(dict)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)
@@ -264,14 +259,14 @@ def evaluate(
            num_fewshot[task_name] = n_shot

            if "task_alias" in configs[task_name]:
-               task_group_alias[task_name] = configs[task_name]["task_alias"]
+               results[task_name]["alias"] = configs[task_name]["task_alias"]

            if (
                ("group_alias" in configs[task_name])
-               and (group_name not in task_group_alias)
+               and (group_name not in results)
                and (group_name is not None)
            ):
-               task_group_alias[group_name] = configs[task_name]["group_alias"]
+               results[group_name]["alias"] = configs[task_name]["group_alias"]

    if limit is not None:
        if task.has_test_docs():
@@ -440,32 +435,6 @@ def evaluate(
        vals = vals_torch

    if lm.rank == 0:
-       ### Get task ordering for correct sample-wide aggregation
-       group_to_task = {}
-       for group in task_hierarchy.keys():
-           if group not in task_order:
-               task_order[group] = 0
-
-           if len(task_hierarchy[group]) > 0:
-               group_to_task[group] = task_hierarchy[group].copy()
-
-               for task in task_hierarchy[group]:
-                   if task in task_order:
-                       task_order[task] += 1
-                   else:
-                       task_order[task] = 1 + task_order[group]
-
-                   if task in task_hierarchy:
-                       group_to_task[group].remove(task)
-                       group_to_task[group].extend(task_hierarchy[task])
-
-       task_to_group = {}
-       for group in group_to_task:
-           for task in group_to_task[group]:
-               if task in task_to_group:
-                   task_to_group[task].append(group)
-               else:
-                   task_to_group[task] = [group]
-
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
@@ -505,7 +474,10 @@ def evaluate(
                total_size = 0
                for task in task_list:
-                   metrics = results[task]
+                   metrics = results[task].copy()
+
+                   if "alias" in metrics:
+                       metrics.pop("alias")

                    current_size = metrics.pop("samples")
                    # TODO: There should be a way for users
@@ -564,71 +536,77 @@ def evaluate(
                        results[group]["samples"] = total_size

-       def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
-           results_agg = collections.defaultdict(dict)
-           groups_agg = collections.defaultdict(dict)
-
-           for group_name, task_list in task_hierarchy.items():
-               order = task_order[group_name]
-               results_agg[group_name] = results[group_name].copy()
-               results_agg[group_name]["tab"] = order
-
-               if (order < max(task_order.values())) and (len(task_list) > 0):
-                   groups_agg[group_name] = results[group_name].copy()
-                   groups_agg[group_name]["tab"] = order
-
-               if task_list != []:
-                   for task in sorted(task_list):
-                       if task in task_hierarchy:
-                           _task_hierarchy = {task: task_hierarchy[task]}
-                       else:
-                           _task_hierarchy = {task: []}
-
-                       _results_agg, _groups_agg, task_version = print_tasks(
-                           _task_hierarchy, task_order, task_version, task_group_alias
-                       )
-
-                       results_agg = {**results_agg, **_results_agg}
-                       groups_agg = {**groups_agg, **_groups_agg}
-
-           return results_agg, groups_agg, task_version
-
-       results_agg, groups_agg, versions = print_tasks(
-           task_hierarchy, task_order, versions, task_group_alias
-       )
-
-       for task in results_agg:
-           task_results = results_agg[task]
-           if "samples" in task_results:
-               task_results.pop("samples")
-
-           tab_string = ""
-           if "tab" in task_results:
-               tab = task_results.pop("tab")
-               tab_string = " " * tab + "- " if tab > 0 else ""
-
-           if task in task_group_alias:
-               task_alias = task_group_alias[task]
-               results_agg[task]["alias"] = tab_string + task_alias
-           else:
-               results_agg[task]["alias"] = tab_string + task
-
-       for group in groups_agg:
-           group_results = groups_agg[group]
-           if "samples" in group_results:
-               group_results.pop("samples")
-
-           tab_string = ""
-           if "tab" in group_results:
-               tab = group_results.pop("tab")
-               tab_string = " " * tab + "- " if tab > 0 else ""
-
-           if group in task_group_alias:
-               group_alias = task_group_alias[group]
-               groups_agg[group]["alias"] = tab_string + group_alias
-           else:
-               groups_agg[group]["alias"] = tab_string + group
+       def print_tasks(task_hierarchy, results, tab=0):
+           results_agg = collections.defaultdict(dict)
+           groups_agg = collections.defaultdict(dict)
+
+           (group_name, task_list), *_ = task_hierarchy.items()
+           task_list = sorted(task_list)
+
+           results_agg[group_name] = results[group_name].copy()
+           # results_agg[group_name]["tab"] = tab
+           if "samples" in results_agg[group_name]:
+               results_agg[group_name].pop("samples")
+
+           tab_string = " " * tab + "- " if tab > 0 else ""
+
+           if "alias" in results_agg[group_name]:
+               results_agg[group_name]["alias"] = (
+                   tab_string + results_agg[group_name]["alias"]
+               )
+           else:
+               results_agg[group_name]["alias"] = tab_string + group_name
+
+           if len(task_list) > 0:
+               groups_agg[group_name] = results[group_name].copy()
+               # groups_agg[group_name]["tab"] = tab
+               if "samples" in groups_agg[group_name]:
+                   groups_agg[group_name].pop("samples")
+
+               if "alias" in groups_agg[group_name]:
+                   groups_agg[group_name]["alias"] = (
+                       tab_string + groups_agg[group_name]["alias"]
+                   )
+               else:
+                   groups_agg[group_name]["alias"] = tab_string + group_name
+
+               for task_name in task_list:
+                   if task_name in task_hierarchy:
+                       _task_hierarchy = {
+                           **{task_name: task_hierarchy[task_name]},
+                           **task_hierarchy,
+                       }
+                   else:
+                       _task_hierarchy = {
+                           **{task_name: []},
+                           **task_hierarchy,
+                       }
+
+                   _results_agg, _groups_agg = print_tasks(
+                       _task_hierarchy, results, tab + 1
+                   )
+                   results_agg = {**results_agg, **_results_agg}
+                   groups_agg = {**groups_agg, **_groups_agg}
+
+           return results_agg, groups_agg
+
+       results_agg = collections.defaultdict(dict)
+       groups_agg = collections.defaultdict(dict)
+       all_tasks_list = list(task_hierarchy.keys())
+       left_tasks_list = []
+       while True:
+           add_tasks_list = list(k for k in results_agg.keys())
+           left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
+           if len(left_tasks_list) == 0:
+               break
+
+           _task_hierarchy = {
+               k: v for k, v in task_hierarchy.items() if k in left_tasks_list
+           }
+           _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)
+           results_agg = {**results_agg, **_results_agg}
+           groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
...
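For intuition about the new print_tasks recursion above: it walks the group-to-task hierarchy and prepends an indented "- " alias per nesting level when building the results table. A minimal standalone sketch of that indentation rule (the task names here are hypothetical, not part of this commit):

task_hierarchy = {"mmlu": ["mmlu_abstract_algebra", "mmlu_anatomy"]}  # hypothetical

def alias_with_tab(name, tab):
    # mirrors: tab_string = " " * tab + "- " if tab > 0 else ""
    return (" " * tab + "- " if tab > 0 else "") + name

print(alias_with_tab("mmlu", 0))                   # "mmlu"
print(alias_with_tab("mmlu_abstract_algebra", 1))  # " - mmlu_abstract_algebra"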
@@ -32,7 +32,7 @@ def build_filter_ensemble(filter_name, components):
    Create a filtering pipeline.
    """
    filters = []
-   for (function, kwargs) in components:
+   for function, kwargs in components:
        if kwargs is None:
            f = get_filter(function)()
        else:
...
@@ -5,5 +5,6 @@ from . import dummy
from . import anthropic_llms
from . import gguf
from . import vllm_causallms
+from . import mamba_lm

# TODO: implement __all__
-from lm_eval.api.model import LM
-from lm_eval.api.registry import register_model
+from typing import Any, List, Tuple

from tqdm import tqdm
-import time

from lm_eval import utils
-from typing import List, Any, Tuple
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions

eval_logger = utils.eval_logger
@@ -45,26 +48,30 @@ def anthropic_completion(
    please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
        )

-   backoff_time: float = 3
-   while True:
-       try:
-           response = client.completions.create(
-               prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
-               model=model,
-               # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
-               # (e.g. gsm8k's ":") may truncate a lot of the input.
-               stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
-               max_tokens_to_sample=max_tokens_to_sample,
-               temperature=temperature,
-               **kwargs,
-           )
-           return response.completion
-       except anthropic.RateLimitError as e:
-           eval_logger.warning(
-               f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
-           )
-           time.sleep(backoff_time)
-           backoff_time *= 1.5
+   def _exception_callback(e: Exception, sleep_time: float) -> None:
+       eval_logger.warning(
+           f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
+       )
+
+   @retry_on_specific_exceptions(
+       on_exceptions=[anthropic.RateLimitError],
+       max_retries=None,  # retry forever, consider changing
+       on_exception_callback=_exception_callback,
+   )
+   def completion():
+       response = client.completions.create(
+           prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
+           model=model,
+           # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
+           # (e.g. gsm8k's ":") may truncate a lot of the input.
+           stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
+           max_tokens_to_sample=max_tokens_to_sample,
+           temperature=temperature,
+           **kwargs,
+       )
+       return response.completion
+
+   return completion()
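Several backends in this commit swap hand-rolled while/try/sleep loops for `lm_eval.utils.retry_on_specific_exceptions`. The harness's real helper lives in lm_eval/utils.py and may differ in detail; the following is only a minimal sketch of a decorator factory with the same call signature used above (on_exceptions, max_retries, on_exception_callback), for readers unfamiliar with the pattern:

import time
from functools import wraps
from typing import Callable, List, Optional, Type

def retry_on_specific_exceptions_sketch(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,  # None = retry forever
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], None]] = None,
):
    """Illustrative only; not the actual lm_eval.utils implementation."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while max_retries is None or attempt < max_retries:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    # notify the caller, then back off exponentially
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
                    attempt += 1
            raise RuntimeError(f"{func.__name__}: exceeded {max_retries} retries")
        return wrapper
    return decorator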
@register_model("anthropic") @register_model("anthropic")
...@@ -141,6 +148,14 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -141,6 +148,14 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def generate_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
)
if not requests: if not requests:
return [] return []
......
import random

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
...
-import requests
import logging
import time
-from tqdm import tqdm
+import requests
from requests.exceptions import RequestException
+from tqdm import tqdm

from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

logger = logging.getLogger(__name__)
...
This diff is collapsed.
from typing import Optional, Union
import torch
from lm_eval import utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
@register_model("mamba_ssm")
class MambaLMWrapper(HFLM):
def __init__(
self,
pretrained="state-spaces/mamba-130m",
**kwargs,
) -> None:
"""
Mamba (via the `mamba_ssm` package) supports the following args:
```
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
ssm_cfg=None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
fused_add_norm=False,
residual_in_fp32=False,
```
See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info.
The above can all be passed via `--model_args` or to this __init__() directly
but we recommend placing many of these within the config.json file uploaded alongside your
Mamba model to the HF Hub instead.
All other HuggingFace from_pretrained() kwargs
such as those related to
`parallelize=True`, PEFT, autoGPTQ,
or any sub-configurations of these advanced args,
are unsupported by the `mamba_ssm` package.
The HFLM arguments
`backend`, `revision`, `subfolder`, `tokenizer`, `truncation`, `max_length`,
`device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`
are all supported by Mamba where they do not conflict
with Mamba-specific restrictions such as causal-LM-only support.
"""
if "backend" in kwargs:
# mamba currently only supports causal models
assert kwargs["backend"] == "causal"
super().__init__(
pretrained=pretrained,
# set appropriate defaults for tokenizer, max length, etc
backend=kwargs.get("backend", "causal"),
tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"),
max_length=kwargs.get("max_length", 2048),
**kwargs,
)
def _get_config(
self,
pretrained: str,
**kwargs,
) -> None:
try:
from mamba_ssm.utils.hf import load_config_hf # noqa: F811
except ModuleNotFoundError:
raise Exception(
"attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
)
self._config = load_config_hf(pretrained)
def _create_model(
self,
pretrained: str,
dtype: Optional[Union[str, torch.dtype]] = "float16",
# no `parallelize=True` options
# no PEFT and quantization options
# Mamba does not support arbitrary HF from_pretrained() args
**kwargs,
) -> None:
try:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel # noqa: F811
except ModuleNotFoundError:
raise Exception(
"attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
)
self._model = MambaLMHeadModel.from_pretrained(
pretrained,
device=self._device,
dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
**kwargs,
)
def _model_generate(self, context, max_length, stop, **generation_kwargs):
for key in ("do_sample", "attention_mask"):
if key in generation_kwargs:
generation_kwargs.pop(key)
# mamba's custom GenerationMixin currently does not support
# passing stopping criteria.
# for the time being, we simply generate to max length,
# then truncate (equivalent result)
# -- this should be revisited to speed up generation
# stopping_criteria = stop_sequences_criteria(
# self.tokenizer, stop, 1, context.shape[0]
# )
return self.model.generate(
input_ids=context,
max_length=max_length,
# stopping_criteria=stopping_criteria,
# pad_token_id=self.tokenizer.pad_token_id,
# use_cache=True,
**generation_kwargs,
)
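As a quick usage sketch (not part of this diff): once the `mamba_ssm` package (and its `causal-conv1d` companion) is installed on a CUDA machine, the wrapper registered above as `mamba_ssm` can be driven through the harness's Python API. The task choice below is only an example.

from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import initialize_tasks

initialize_tasks()  # register the bundled task configs (see lm_eval/tasks/__init__.py below)
results = simple_evaluate(
    model="mamba_ssm",
    model_args="pretrained=state-spaces/mamba-130m",  # default checkpoint named in the wrapper above
    tasks=["lambada_openai"],  # hypothetical task choice
    batch_size=8,
)
print(results["results"])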
-import os
-import time
-from typing import List, Tuple, Optional
import copy
+import os
from collections import defaultdict
+from importlib.util import find_spec
+from typing import List, Optional, Tuple

from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions
def get_result(response, ctxlen: int) -> Tuple[float, bool]: def get_result(response, ctxlen: int) -> Tuple[float, bool]:
...@@ -44,24 +45,28 @@ def oa_completion(**kwargs): ...@@ -44,24 +45,28 @@ def oa_completion(**kwargs):
Retry with back-off until they respond Retry with back-off until they respond
""" """
try: if not find_spec("openai") or not find_spec("tiktoken"):
import openai, tiktoken # noqa: E401
except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`", "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
) )
else:
import openai
backoff_time = 3 def _exception_callback(e: Exception, sleep_time: float) -> None:
while True: import traceback
try:
return openai.completions.create(**kwargs) traceback.print_exc()
except openai.OpenAIError:
import traceback
traceback.print_exc() @retry_on_specific_exceptions(
time.sleep(backoff_time) on_exceptions=[openai.OpenAIError],
backoff_time *= 1.5 max_retries=None, # retry forever, consider changing
on_exception_callback=_exception_callback,
)
def completion():
return openai.completions.create(**kwargs)
return completion()
@register_model("openai-completions") @register_model("openai-completions")
...@@ -71,7 +76,7 @@ class OpenaiCompletionsLM(LM): ...@@ -71,7 +76,7 @@ class OpenaiCompletionsLM(LM):
def __init__( def __init__(
self, self,
model: str = "text-davinci-003", model: str,
truncate: bool = False, truncate: bool = False,
max_gen_toks: int = 256, max_gen_toks: int = 256,
batch_size: int = 1, batch_size: int = 1,
...@@ -81,14 +86,15 @@ class OpenaiCompletionsLM(LM): ...@@ -81,14 +86,15 @@ class OpenaiCompletionsLM(LM):
""" """
:param engine: str :param engine: str
OpenAI API engine (e.g. davinci) OpenAI API engine (e.g. gpt-3.5-turbo-instruct)
:param truncate: bool :param truncate: bool
Truncate input if too long (if False and input is too long, throw error) Truncate input if too long (if False and input is too long, throw error)
""" """
super().__init__() super().__init__()
self.seed = seed self.seed = seed
try: try:
import openai, tiktoken # noqa: E401 import openai # noqa: E401
import tiktoken
except ModuleNotFoundError: except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
...@@ -102,7 +108,7 @@ class OpenaiCompletionsLM(LM): ...@@ -102,7 +108,7 @@ class OpenaiCompletionsLM(LM):
self._max_gen_toks = max_gen_toks self._max_gen_toks = max_gen_toks
self._max_length = max_length self._max_length = max_length
# Read from environment variable OPENAI_API_SECRET_KEY # Read from environment variable OPENAI_API_KEY
openai.api_key = os.environ["OPENAI_API_KEY"] openai.api_key = os.environ["OPENAI_API_KEY"]
@property @property
...@@ -154,8 +160,9 @@ class OpenaiCompletionsLM(LM): ...@@ -154,8 +160,9 @@ class OpenaiCompletionsLM(LM):
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
# end of text as context # end of text as context
context_enc, continuation_enc = [self.eot_token_id], self.tok_encode( context_enc, continuation_enc = (
continuation [self.eot_token_id],
self.tok_encode(continuation),
) )
else: else:
context_enc, continuation_enc = self._encode_pair(context, continuation) context_enc, continuation_enc = self._encode_pair(context, continuation)
...@@ -247,6 +254,7 @@ class OpenaiCompletionsLM(LM): ...@@ -247,6 +254,7 @@ class OpenaiCompletionsLM(LM):
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)) list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
): ):
inps = [] inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tok_encode(context) context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :] inp = context_enc[-(self.max_length - self.max_gen_toks) :]
...@@ -326,68 +334,68 @@ def oa_chat_completion(client, **kwargs): ...@@ -326,68 +334,68 @@ def oa_chat_completion(client, **kwargs):
Retry with back-off until they respond Retry with back-off until they respond
""" """
try: if not find_spec("openai") or not find_spec("tiktoken"):
import openai, tiktoken # noqa: E401
except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`", "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
) )
else:
import openai
async def _get_completions(**kwargs): def _exception_callback(e: Exception, sleep_time: float) -> None:
chat_completions = await client.chat.completions.create(**kwargs) import traceback
return chat_completions
backoff_time = 3 traceback.print_exc()
while True:
try: @retry_on_specific_exceptions(
return client.chat.completions.create(**kwargs) on_exceptions=[openai.OpenAIError],
except openai.OpenAIError: max_retries=None, # retry forever, consider changing
import traceback on_exception_callback=_exception_callback,
)
def completion():
return client.chat.completions.create(**kwargs)
traceback.print_exc() return completion()
time.sleep(backoff_time)
backoff_time *= 1.5
@register_model("openai-chat-completions") @register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM): class OpenaiChatCompletionsLM(LM):
def __init__( def __init__(
self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1 self,
model: str = "gpt-3.5-turbo", # GPT model or Local model using HuggingFace model paths
base_url: str = None,
truncate: bool = False,
**kwargs,
) -> None: ) -> None:
""" """
:param model: str :param model: str
Implements an OpenAI-style chat completion API for
accessing both OpenAI OR locally-hosted models using
HuggingFace Tokenizer
OpenAI API model (e.g. gpt-3.5-turbo) OpenAI API model (e.g. gpt-3.5-turbo)
using the **gen_kwargs passed on init
:param truncate: bool :param truncate: bool
Truncate input if too long (if False and input is too long, throw error) Truncate input if too long (if False and input is too long, throw error)
""" """
super().__init__() super().__init__()
try: try:
import openai, tiktoken # noqa: E401 import openai # noqa: E401
except ModuleNotFoundError: except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`", please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
) )
self.model = model self.model = model
self.frequency_penalty = 0 self.base_url = base_url
self.logit_bias = None
self.n = 1
self.presence_penalty = 0
self.temperature = 1
self.top_p = 1
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.truncate = truncate self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token
# Read from environment variable OPENAI_API_KEY # Read from environment variable OPENAI_API_KEY
self.client = openai.OpenAI() # openai.AsyncOpenAI() # Set to EMPTY for local
if self.base_url:
@property self.client = openai.OpenAI(base_url=self.base_url)
def eot_token_id(self): else:
return self.end_of_text_token_id self.client = openai.OpenAI() # openai.AsyncOpenAI()
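A usage sketch for the new `local-chat-completions` alias (not from this diff): the server URL and model name below are placeholders, and an OpenAI-compatible chat server is assumed to already be running locally.

import os

from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import initialize_tasks

os.environ["OPENAI_API_KEY"] = "EMPTY"  # "Set to EMPTY for local", per the comment above
initialize_tasks()

results = simple_evaluate(
    model="local-chat-completions",
    model_args="model=my-org/my-chat-model,base_url=http://localhost:8000/v1",  # placeholder names
    tasks=["gsm8k"],  # hypothetical task choice
)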
@property @property
def max_length(self) -> int: def max_length(self) -> int:
...@@ -408,53 +416,19 @@ class OpenaiChatCompletionsLM(LM): ...@@ -408,53 +416,19 @@ class OpenaiChatCompletionsLM(LM):
# Isn't used because we override _loglikelihood_tokens # Isn't used because we override _loglikelihood_tokens
raise NotImplementedError() raise NotImplementedError()
def tok_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string)
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def generate_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
res = defaultdict(list) res = defaultdict(list)
re_ords = {} re_ords = {}
def _collate(x):
toks = self.tok_encode(x[0])
return -len(toks), x[0]
# we group requests by their generation_kwargs, # we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch. # in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1])) grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items(): for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending. # within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate) re_ords[key] = utils.Reorderer(
[req.args for req in reqs], lambda x: (-len(x[0]), x[0])
def sameuntil_chunks(xs, size): )
ret = []
lastuntil = xs[0][1]
for x in xs:
if len(ret) >= size or x[1] != lastuntil:
yield ret, lastuntil
ret = []
lastuntil = x[1]
ret.append(x)
if ret:
yield ret, lastuntil
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
for key, re_ord in re_ords.items(): for key, re_ord in re_ords.items():
...@@ -468,37 +442,26 @@ class OpenaiChatCompletionsLM(LM): ...@@ -468,37 +442,26 @@ class OpenaiChatCompletionsLM(LM):
gen_kwargs = all_gen_kwargs[0] gen_kwargs = all_gen_kwargs[0]
until = None until = None
if isinstance(gen_kwargs, dict): if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 if "do_sample" in kwargs.keys():
kwargs.pop("do_sample")
if "until" in kwargs.keys(): if "until" in kwargs.keys():
until = kwargs.pop("until") until = kwargs.pop("until")
if isinstance(until, str): if isinstance(until, str):
until = [kwargs] until = [kwargs]
elif not isinstance(until, list): elif not isinstance(until, list):
raise ValueError( raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
) )
kwargs["stop"] = until
kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
else: else:
raise ValueError( raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}" f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
) )
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
response = oa_chat_completion( response = oa_chat_completion(
client=self.client, client=self.client, messages=inps, model=self.model, **kwargs
messages=inps,
model=self.model,
frequency_penalty=self.frequency_penalty,
# logit_bias=self.logit_bias,
max_tokens=max_gen_toks,
n=self.n,
presence_penalty=self.presence_penalty,
temperature=self.temperature,
top_p=self.top_p,
) )
for resp, (context, args_) in zip(response.choices, chunk): for resp, (context, args_) in zip(response.choices, chunk):
......
...@@ -13,11 +13,13 @@ Homepage: https://textsynth.com/index.html ...@@ -13,11 +13,13 @@ Homepage: https://textsynth.com/index.html
""" """
import logging import logging
import os import os
import requests as _requests import requests as _requests
import time
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval.utils import retry_on_specific_exceptions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -27,21 +29,26 @@ def textsynth_completion(**kwargs): ...@@ -27,21 +29,26 @@ def textsynth_completion(**kwargs):
"""Query TextSynth API for completion. """Query TextSynth API for completion.
Retry with back-off until they respond. Retry with back-off until they respond.
""" """
backoff_time = 3
while True:
try:
return _requests.post(**kwargs)
except _requests.exceptions.RequestException:
import traceback
traceback.print_exc() def _exception_callback(e: Exception, sleep_time: float) -> None:
time.sleep(backoff_time) import traceback
backoff_time *= 1.5
traceback.print_exc()
@retry_on_specific_exceptions(
on_exceptions=[_requests.exceptions.RequestException],
max_retries=None, # retry forever, consider changing
on_exception_callback=_exception_callback,
)
def completion():
return _requests.post(**kwargs)
return completion()
@register_model("textsynth") @register_model("textsynth")
class TextSynthLM(LM): class TextSynthLM(LM):
def __init__(self, engine, truncate: bool = False) -> None: def __init__(self, engine, truncate: bool = False, **kwargs) -> None:
""" """
:param engine: str :param engine: str
TextSynth API engine (e.g. `gptj_6B`) TextSynth API engine (e.g. `gptj_6B`)
...@@ -149,7 +156,7 @@ class TextSynthLM(LM): ...@@ -149,7 +156,7 @@ class TextSynthLM(LM):
self.cache_hook.add_partial("generate_until", (inp, request_args), s) self.cache_hook.add_partial("generate_until", (inp, request_args), s)
else: else:
logger.error( logger.error(
f"The following response does not contain generated `text`. " "The following response does not contain generated `text`. "
"Got:\n{resp}" "Got:\n{resp}"
) )
assert False assert False
......
from collections import defaultdict
from typing import List, Tuple, Optional, Literal, Union, Any
from transformers import AutoTokenizer
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
import copy import copy
from importlib.util import find_spec
from typing import List, Literal, Optional, Tuple, Union
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval import utils from lm_eval.utils import (
Collator,
divide,
eval_logger,
get_rolling_token_windows,
make_disjoint_window,
)
try: try:
from vllm import LLM, SamplingParams import ray
from ray.util.multiprocessing import Pool from ray.util.multiprocessing import Pool
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
eval_logger = utils.eval_logger eval_logger = eval_logger
# adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727 # adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
def run_inference_one_model(model_args: dict, sampling_params, requests: List[int]): def run_inference_one_model(
# gpu_id = [x for x in gpu_id] model_args: dict, sampling_params, requests: List[List[int]]
# os.environ["CUDA_VISIBLE_DEVICES"]= str(gpu_id) ):
llm = LLM(**model_args) llm = LLM(**model_args)
return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params) return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
...@@ -40,7 +49,7 @@ class VLLM(LM): ...@@ -40,7 +49,7 @@ class VLLM(LM):
tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
quantization: Optional[Literal["awq"]] = None, quantization: Optional[str] = None,
max_gen_toks: int = 256, max_gen_toks: int = 256,
swap_space: int = 4, swap_space: int = 4,
batch_size: Union[str, int] = 1, batch_size: Union[str, int] = 1,
...@@ -54,12 +63,10 @@ class VLLM(LM): ...@@ -54,12 +63,10 @@ class VLLM(LM):
): ):
super().__init__() super().__init__()
try: if not find_spec("vllm"):
import vllm
except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'vllm' LM type, but package `vllm` is not installed. \ "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`", "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
) )
assert "cuda" in device or device is None, "vLLM only supports CUDA" assert "cuda" in device or device is None, "vLLM only supports CUDA"
...@@ -85,17 +92,30 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -85,17 +92,30 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
"quantization": quantization, "quantization": quantization,
"seed": int(seed), "seed": int(seed),
} }
self.batch_size = (
"auto"
if isinstance(batch_size, str) and "auto" in batch_size
else batch_size
)
if self.data_parallel_size <= 1: if self.data_parallel_size <= 1:
self.model = LLM(**self.model_args) self.model = LLM(**self.model_args)
else: else:
self.model_args["worker_use_ray"] = True self.model_args["worker_use_ray"] = True
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
from transformers import AutoConfig
self._config = AutoConfig.from_pretrained(
pretrained, trust_remote_code=trust_remote_code, revision=revision
)
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
tokenizer if tokenizer else pretrained, tokenizer if tokenizer else pretrained,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, tokenizer_revision=tokenizer_revision,
) )
self.batch_size = batch_size
self._max_gen_toks = max_gen_toks self._max_gen_toks = max_gen_toks
@property @property
...@@ -107,9 +127,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -107,9 +127,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
def max_length(self): def max_length(self):
if self._max_length: # if max length manually set, return it if self._max_length: # if max length manually set, return it
return self._max_length return self._max_length
if hasattr(self.tokenizer, "model_max_length"): if self.data_parallel_size <= 1:
return self.tokenizer.model_max_length return self.model.llm_engine.model_config.max_model_len
return self._DEFAULT_MAX_LENGTH else:
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self._config, attr):
return getattr(self._config, attr)
if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
@property @property
def max_gen_toks(self): def max_gen_toks(self):
...@@ -155,13 +184,13 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -155,13 +184,13 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
temperature=0, prompt_logprobs=2, max_tokens=1 temperature=0, prompt_logprobs=2, max_tokens=1
) )
if self.data_parallel_size > 1: if self.data_parallel_size > 1:
requests = [ requests = [list(x) for x in divide(requests, self.data_parallel_size)]
list(x) for x in utils.divide(requests, self.data_parallel_size)
]
inputs = [(self.model_args, sampling_params, req) for req in requests] inputs = [(self.model_args, sampling_params, req) for req in requests]
with Pool(self.data_parallel_size) as pool: with Pool(self.data_parallel_size) as pool:
results = pool.starmap(run_inference_one_model, inputs) results = pool.starmap(run_inference_one_model, inputs)
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
ray.shutdown()
# flatten results # flatten results
return [item for sublist in results for item in sublist] return [item for sublist in results for item in sublist]
...@@ -170,7 +199,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -170,7 +199,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False, use_tqdm=True if self.batch_size == "auto" else False,
) )
return outputs return outputs
def _encode_pair( def _encode_pair(
...@@ -193,8 +221,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -193,8 +221,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
# end of text as context # end of text as context
context_enc, continuation_enc = [self.eot_token_id], self.tok_encode( context_enc, continuation_enc = (
continuation [self.eot_token_id],
self.tok_encode(continuation),
) )
else: else:
context_enc, continuation_enc = self._encode_pair(context, continuation) context_enc, continuation_enc = self._encode_pair(context, continuation)
...@@ -209,8 +238,8 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -209,8 +238,8 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
for (string,) in tqdm([req.args for req in requests]): for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list( rolling_token_windows = list(
map( map(
utils.make_disjoint_window, make_disjoint_window,
utils.get_rolling_token_windows( get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.eot_token_id,
max_seq_len=self.max_length - 1, max_seq_len=self.max_length - 1,
...@@ -233,8 +262,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -233,8 +262,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
return loglikelihoods return loglikelihoods
def generate_until(self, requests: List[Instance]) -> List[str]: def generate_until(self, requests: List[Instance]) -> List[str]:
res = defaultdict(list) res = []
re_ords = {}
# batch tokenize contexts # batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests)) context, all_gen_kwargs = zip(*(req.args for req in requests))
...@@ -250,84 +278,73 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -250,84 +278,73 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
# padded context length. this is useful to simplify the batching logic and more importantly to make # padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement # automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end # - any OOMs will happen right away rather than near the end
return -len(_requests[0][1]), tuple(_requests[0][1]) return -len(_requests[0][1]), _requests[0][0]
# we group requests by their generation_kwargs, # we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch. # in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x[1])) re_ords = Collator(requests, _collate_gen, grouping=True)
for key, reqs in grouper.get_grouped().items(): chunks = re_ords.get_batched(
# within each set of reqs for given kwargs, we reorder by token length, descending. n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
re_ords[key] = utils.Reorderer(requests, _collate_gen) )
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items(): for chunk in chunks:
chunks = utils.chunks( context_and_encoding, all_gen_kwargs = zip(*chunk)
re_ord.get_reordered(), context, context_encoding = zip(*context_and_encoding)
n=int(self.batch_size) if self.batch_size != "auto" else 0, # we assume all gen kwargs in the batch are the same
fn=None, # this is safe to assume because the `grouper` object ensures it.
) gen_kwargs = all_gen_kwargs[0]
for chunk in chunks: # unpack our keyword arguments.
context_and_encoding, all_gen_kwargs = zip(*chunk) until = None
context, context_encoding = zip(*context_and_encoding) if isinstance(gen_kwargs, dict):
# we assume all gen kwargs in the batch are the same kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
# this is safe to assume because the `grouper` object ensures it. if "until" in kwargs.keys():
gen_kwargs = all_gen_kwargs[0] until = kwargs.pop("until")
# unpack our keyword arguments. if isinstance(until, str):
until = None until = [until]
if isinstance(gen_kwargs, dict): elif not isinstance(until, list):
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1 raise ValueError(
if "until" in kwargs.keys(): f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
until = kwargs.pop("until") )
if isinstance(until, str): else:
until = [until] raise ValueError(
elif not isinstance(until, list): f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
raise ValueError(
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {gen_kwargs}"
)
if not until:
until = [self.tokenizer.decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# TODO: max_length in kwargs
# perform batched generation
cont = self._model_generate(
requests=context_encoding,
generate=True,
max_tokens=max_gen_toks,
stop=until,
**kwargs,
) )
if not until:
until = [self.tokenizer.decode(self.eot_token_id)]
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
max_gen_toks = self.max_gen_toks
# set the max length in tokens of inputs ("context_enc")
# max len for inputs = max length, minus room to generate the max new tokens
max_ctx_len = self.max_length - max_gen_toks
context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# perform batched generation
cont = self._model_generate(
requests=context_encoding,
generate=True,
max_tokens=max_gen_toks,
stop=until,
**kwargs,
)
# cache generations # cache generations
for output, context in zip(cont, context): for output, context in zip(cont, context):
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
res[key].append(generated_text) res.append(generated_text)
self.cache_hook.add_partial( self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text "generate_until", (context, gen_kwargs), generated_text
) )
pbar.update(1) pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close() pbar.close()
# reorder all group of results back to original unsorted form
return grouper.get_original(res) return re_ords.get_original(res)
def _loglikelihood_tokens( def _loglikelihood_tokens(
self, self,
...@@ -340,16 +357,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -340,16 +357,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate) # Reorder requests by length and batch
re_ord = Collator(requests, sort_fn=_collate)
chunks = utils.chunks( chunks = re_ord.get_batched(
re_ord.get_reordered(), n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
n=int(self.batch_size) if self.batch_size != "auto" else 0,
fn=None,
) )
pbar = tqdm(total=len(requests), disable=disable_tqdm) pbar = tqdm(total=len(requests), disable=disable_tqdm)
for chunk in chunks: for chunk in chunks:
inps = [] inputs = []
ctxlens = [] ctxlens = []
for cache_key, context_enc, continuation_enc in chunk: for cache_key, context_enc, continuation_enc in chunk:
inp = (context_enc + continuation_enc)[-(self.max_length) :] inp = (context_enc + continuation_enc)[-(self.max_length) :]
...@@ -357,18 +373,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -357,18 +373,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
0, len(context_enc) + len(continuation_enc) - (self.max_length) 0, len(context_enc) + len(continuation_enc) - (self.max_length)
) )
inps.append(inp) inputs.append(inp)
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
outputs = self._model_generate(requests=inps, generate=False) outputs = self._model_generate(requests=inputs, generate=False)
for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip( for output, ctxlen, (cache_key, _, _), inp in zip(
outputs, ctxlens, chunk outputs, ctxlens, chunk, inputs
): ):
answer = self._parse_logprobs( answer = self._parse_logprobs(
(context_enc + continuation_enc), tokens=inp,
output, outputs=output,
ctxlen, ctxlen=ctxlen,
) )
res.append(answer) res.append(answer)
...@@ -376,7 +392,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -376,7 +392,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
# partial caching # partial caching
if cache_key is not None: if cache_key is not None:
self.cache_hook.add_partial("loglikelihood", cache_key, answer) self.cache_hook.add_partial("loglikelihood", cache_key, answer)
pbar.update(1) pbar.update(1)
pbar.close() pbar.close()
return re_ord.get_original(res) return re_ord.get_original(res)
...@@ -385,9 +401,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -385,9 +401,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
"""Process logprobs and tokens. """Process logprobs and tokens.
:param tokens: list :param tokens: list
Tokens from context+continuations Input tokens (potentially left-truncated)
:param outputs: RequestOutput :param outputs: RequestOutput
Contains prompt Contains prompt_logprobs
:param ctxlen: int :param ctxlen: int
Length of context (so we can slice them away and only keep the predictions) Length of context (so we can slice them away and only keep the predictions)
:return: :return:
...@@ -397,11 +413,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -397,11 +413,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
Whether argmax matches given continuation exactly Whether argmax matches given continuation exactly
""" """
# prompt_logprobs = [None, {}*len(context-1)] # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
continuation_logprobs_dicts = outputs.prompt_logprobs continuation_logprobs_dicts = outputs.prompt_logprobs
# Calculate continuation_logprobs # Calculate continuation_logprobs
# assume ctxlen always > 1 # assume ctxlen always >= 1
continuation_logprobs = sum( continuation_logprobs = sum(
logprob_dict.get(token) logprob_dict.get(token)
for token, logprob_dict in zip( for token, logprob_dict in zip(
......
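For intuition about the continuation scoring described in the `_parse_logprobs` docstring above: the prompt logprobs are summed only past the context length, roughly as in this toy sketch (plain dicts and made-up values standing in for vLLM's real prompt_logprobs output):

# one {token_id: logprob} mapping per prompt position; the first entry is None in vLLM
prompt_logprobs = [None, {11: -0.2}, {57: -1.3}, {89: -0.7}]
tokens = [5, 11, 57, 89]  # context + continuation token ids
ctxlen = 2                # first two tokens are context

continuation_logprob = sum(
    logprob_dict[token]
    for token, logprob_dict in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
)
print(continuation_logprob)  # -2.0 : log-probability of the continuation given the context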
@@ -69,7 +69,6 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
def load_prompt_list(
    use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
):
-
    category_name, prompt_name = use_prompt.split(":")

    if category_name == "promptsource":
@@ -113,7 +112,6 @@ class PromptString:
        self.prompt_string = prompt_string

    def apply(self, doc):
-
        doc_to_text = self.prompt_string["doc_to_text"]
        doc_to_target = self.prompt_string["doc_to_target"]
...
...@@ -61,11 +61,27 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -61,11 +61,27 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
task_list = [task for task in all_task_list if type(task) == str] task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list: for task_config in config_list:
base_config = {}
task_name_config = {}
if "task" in task_config:
task_name = task_config["task"]
if task_name in ALL_TASKS:
task_obj = get_task_dict(task_name)[task_name]
if type(task_obj) == tuple:
_, task_obj = task_obj
if task_obj is not None:
base_config = task_obj._config.to_dict()
task_name_config["task"] = f"{group}_{task_name}"
task_config = utils.load_yaml_config(yaml_path, task_config) task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config( var_configs = check_prompt_config(
{ {
**base_config,
**task_config, **task_config,
**{"group": group}, **{"group": group},
**task_name_config,
}, },
yaml_path=os.path.dirname(yaml_path), yaml_path=os.path.dirname(yaml_path),
) )
...@@ -131,7 +147,10 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None: ...@@ -131,7 +147,10 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
""" """
Calling this function Calling this function
""" """
for root, subdirs, file_list in reversed(list(os.walk(task_dir))):
# Track whether any tasks failed during loading
import_fail = False
for root, subdirs, file_list in os.walk(task_dir):
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0): # if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
...@@ -155,20 +174,27 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None: ...@@ -155,20 +174,27 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
# Log this silently and show it only when # Log this silently and show it only when
# the user defines the appropriate verbosity. # the user defines the appropriate verbosity.
except ModuleNotFoundError as e: except (ImportError, ModuleNotFoundError) as e:
import_fail = True
eval_logger.debug( eval_logger.debug(
f"{yaml_path}: {e}. Config will not be added to registry." f"{yaml_path}: {e}. Config will not be added to registry."
) )
except Exception as error: except Exception as error:
import traceback import traceback
eval_logger.debug( eval_logger.warning(
"Failed to load config in\n" "Unexpected error loading config in\n"
f" {yaml_path}\n" f" {yaml_path}\n"
" Config will not be added to registry\n" " Config will not be added to registry\n"
f" Error: {error}\n" f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}" f" Traceback: {traceback.format_exc()}"
) )
if import_fail:
eval_logger.warning(
"Some tasks could not be loaded due to missing dependencies."
" Run with `--verbosity DEBUG` for full details."
)
return 0 return 0
...@@ -180,7 +206,6 @@ def include_path(task_dir): ...@@ -180,7 +206,6 @@ def include_path(task_dir):
def initialize_tasks(verbosity="INFO"): def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
......
@@ -23,4 +23,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 1.0
+  version: 1.0
@@ -20,4 +20,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 1.0
+  version: 1.0
@@ -13,4 +13,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 1.0
+  version: 1.0
@@ -11,4 +11,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 1.0
+  version: 1.0
@@ -17,4 +17,4 @@ metric_list:
    aggregation: mean
    higher_is_better: true
metadata:
-  - version: 0.0
+  version: 1.0
@@ -24,7 +24,6 @@ def parse_args():
if __name__ == "__main__":
    args = parse_args()
-
    # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
@@ -37,7 +36,6 @@ if __name__ == "__main__":
    dataset_path = "lukaemon/bbh"
    for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
-
        resp = requests.get(
            f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt"
        ).content.decode("utf-8")
...
@@ -27,4 +27,4 @@ filter_list:
      - function: "take_first"
num_fewshot: 0
metadata:
-  - version: 1.0
+  version: 2.0