Commit 0d1ef037 authored by lintangsutawika

solved merge conflict

parents aa44be3f ada4a31d
 import random
 import itertools
-import json
 import collections
-import sys
 import torch
...@@ -17,8 +15,6 @@ import lm_eval.api.registry ...@@ -17,8 +15,6 @@ import lm_eval.api.registry
from lm_eval.utils import ( from lm_eval.utils import (
positional_deprecated, positional_deprecated,
run_task_tests, run_task_tests,
make_table,
create_iterator,
get_git_commit_hash, get_git_commit_hash,
simple_parse_args_string, simple_parse_args_string,
eval_logger, eval_logger,
...@@ -91,7 +87,7 @@ def simple_evaluate( ...@@ -91,7 +87,7 @@ def simple_evaluate(
if gen_kwargs is not None: if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs) gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning( eval_logger.warning(
f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks." "generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
) )
if gen_kwargs == "": if gen_kwargs == "":
gen_kwargs = None gen_kwargs = None
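Note: the gen_kwargs override above arrives from the CLI as a single comma-separated string and is turned into a dict by simple_parse_args_string. A rough sketch of that parsing (illustrative only; the exact type-coercion rules in lm_eval.utils.simple_parse_args_string may differ):

def parse_args_string(arg_string: str) -> dict:
    """Parse 'k1=v1,k2=v2' into a dict with best-effort type coercion."""
    args = {}
    for pair in filter(None, arg_string.split(",")):
        key, value = pair.split("=", 1)
        if value in ("True", "False"):
            args[key] = value == "True"
        else:
            try:
                args[key] = int(value)
            except ValueError:
                try:
                    args[key] = float(value)
                except ValueError:
                    args[key] = value
    return args

print(parse_args_string("temperature=0,top_p=0.9,do_sample=False"))
# {'temperature': 0, 'top_p': 0.9, 'do_sample': False}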
...@@ -118,7 +114,9 @@ def simple_evaluate( ...@@ -118,7 +114,9 @@ def simple_evaluate(
use_cache use_cache
# each rank receives a different cache db. # each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once # necessary to avoid multiple writes to cache at once
+ "_rank" + str(lm.rank) + ".db", + "_rank"
+ str(lm.rank)
+ ".db",
) )
task_dict = lm_eval.tasks.get_task_dict(tasks) task_dict = lm_eval.tasks.get_task_dict(tasks)
...@@ -234,9 +232,6 @@ def evaluate( ...@@ -234,9 +232,6 @@ def evaluate(
padding_requests = collections.defaultdict(int) padding_requests = collections.defaultdict(int)
# store the hierarchy to do proper ordering # store the hierarchy to do proper ordering
task_hierarchy = collections.defaultdict(list) task_hierarchy = collections.defaultdict(list)
# store the ordering of tasks and groups
task_order = collections.defaultdict(int)
task_group_alias = collections.defaultdict(dict)
# store num-fewshot value per task # store num-fewshot value per task
num_fewshot = collections.defaultdict(int) num_fewshot = collections.defaultdict(int)
...@@ -264,14 +259,14 @@ def evaluate( ...@@ -264,14 +259,14 @@ def evaluate(
num_fewshot[task_name] = n_shot num_fewshot[task_name] = n_shot
if "task_alias" in configs[task_name]: if "task_alias" in configs[task_name]:
task_group_alias[task_name] = configs[task_name]["task_alias"] results[task_name]["alias"] = configs[task_name]["task_alias"]
if ( if (
("group_alias" in configs[task_name]) ("group_alias" in configs[task_name])
and (group_name not in task_group_alias) and (group_name not in results)
and (group_name is not None) and (group_name is not None)
): ):
task_group_alias[group_name] = configs[task_name]["group_alias"] results[group_name]["alias"] = configs[task_name]["group_alias"]
if limit is not None: if limit is not None:
if task.has_test_docs(): if task.has_test_docs():
...@@ -440,32 +435,6 @@ def evaluate( ...@@ -440,32 +435,6 @@ def evaluate(
vals = vals_torch vals = vals_torch
if lm.rank == 0: if lm.rank == 0:
### Get task ordering for correct sample-wide aggregation
group_to_task = {}
for group in task_hierarchy.keys():
if group not in task_order:
task_order[group] = 0
if len(task_hierarchy[group]) > 0:
group_to_task[group] = task_hierarchy[group].copy()
for task in task_hierarchy[group]:
if task in task_order:
task_order[task] += 1
else:
task_order[task] = 1 + task_order[group]
if task in task_hierarchy:
group_to_task[group].remove(task)
group_to_task[group].extend(task_hierarchy[task])
task_to_group = {}
for group in group_to_task:
for task in group_to_task[group]:
if task in task_to_group:
task_to_group[task].append(group)
else:
task_to_group[task] = [group]
### Aggregate results over all datapoints ### ### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs # aggregate results ; run bootstrap CIs
...@@ -505,7 +474,10 @@ def evaluate( ...@@ -505,7 +474,10 @@ def evaluate(
total_size = 0 total_size = 0
for task in task_list: for task in task_list:
-                metrics = results[task]
+                metrics = results[task].copy()
+                if "alias" in metrics:
+                    metrics.pop("alias")
current_size = metrics.pop("samples") current_size = metrics.pop("samples")
# TODO: There should be a way for users # TODO: There should be a way for users
...@@ -564,71 +536,77 @@ def evaluate( ...@@ -564,71 +536,77 @@ def evaluate(
results[group]["samples"] = total_size results[group]["samples"] = total_size
-        def print_tasks(task_hierarchy, task_order, task_version, task_group_alias):
+        def print_tasks(task_hierarchy, results, tab=0):
             results_agg = collections.defaultdict(dict)
             groups_agg = collections.defaultdict(dict)
-            for group_name, task_list in task_hierarchy.items():
-                order = task_order[group_name]
-                results_agg[group_name] = results[group_name].copy()
-                results_agg[group_name]["tab"] = order
-
-                if (order < max(task_order.values())) and (len(task_list) > 0):
-                    groups_agg[group_name] = results[group_name].copy()
-                    groups_agg[group_name]["tab"] = order
-
-                if task_list != []:
-                    for task in sorted(task_list):
-                        if task in task_hierarchy:
-                            _task_hierarchy = {task: task_hierarchy[task]}
-                        else:
-                            _task_hierarchy = {task: []}
-
-                        _results_agg, _groups_agg, task_version = print_tasks(
-                            _task_hierarchy, task_order, task_version, task_group_alias
-                        )
-
-                        results_agg = {**results_agg, **_results_agg}
-                        groups_agg = {**groups_agg, **_groups_agg}
-
-            return results_agg, groups_agg, task_version
-
-        results_agg, groups_agg, versions = print_tasks(
-            task_hierarchy, task_order, versions, task_group_alias
-        )
-
-        for task in results_agg:
-            task_results = results_agg[task]
-
-            if "samples" in task_results:
-                task_results.pop("samples")
-
-            tab_string = ""
-            if "tab" in task_results:
-                tab = task_results.pop("tab")
-                tab_string = " " * tab + "- " if tab > 0 else ""
-
-            if task in task_group_alias:
-                task_alias = task_group_alias[task]
-                results_agg[task]["alias"] = tab_string + task_alias
-            else:
-                results_agg[task]["alias"] = tab_string + task
-
-        for group in groups_agg:
-            group_results = groups_agg[group]
-
-            if "samples" in group_results:
-                group_results.pop("samples")
-
-            tab_string = ""
-            if "tab" in group_results:
-                tab = group_results.pop("tab")
-                tab_string = " " * tab + "- " if tab > 0 else ""
-
-            if group in task_group_alias:
-                group_alias = task_group_alias[group]
-                groups_agg[group]["alias"] = tab_string + group_alias
-            else:
-                groups_agg[group]["alias"] = tab_string + group
+
+            (group_name, task_list), *_ = task_hierarchy.items()
+            task_list = sorted(task_list)
+
+            results_agg[group_name] = results[group_name].copy()
+            # results_agg[group_name]["tab"] = tab
+            if "samples" in results_agg[group_name]:
+                results_agg[group_name].pop("samples")
+
+            tab_string = " " * tab + "- " if tab > 0 else ""
+
+            if "alias" in results_agg[group_name]:
+                results_agg[group_name]["alias"] = (
+                    tab_string + results_agg[group_name]["alias"]
+                )
+            else:
+                results_agg[group_name]["alias"] = tab_string + group_name
+
+            if len(task_list) > 0:
+                groups_agg[group_name] = results[group_name].copy()
+                # groups_agg[group_name]["tab"] = tab
+                if "samples" in groups_agg[group_name]:
+                    groups_agg[group_name].pop("samples")
+
+                if "alias" in groups_agg[group_name]:
+                    groups_agg[group_name]["alias"] = (
+                        tab_string + groups_agg[group_name]["alias"]
+                    )
+                else:
+                    groups_agg[group_name]["alias"] = tab_string + group_name
+
+                for task_name in task_list:
+                    if task_name in task_hierarchy:
+                        _task_hierarchy = {
+                            **{task_name: task_hierarchy[task_name]},
+                            **task_hierarchy,
+                        }
+                    else:
+                        _task_hierarchy = {
+                            **{task_name: []},
+                            **task_hierarchy,
+                        }
+
+                    _results_agg, _groups_agg = print_tasks(
+                        _task_hierarchy, results, tab + 1
+                    )
+
+                    results_agg = {**results_agg, **_results_agg}
+                    groups_agg = {**groups_agg, **_groups_agg}
+
+            return results_agg, groups_agg
+
+        results_agg = collections.defaultdict(dict)
+        groups_agg = collections.defaultdict(dict)
+        all_tasks_list = list(task_hierarchy.keys())
+        left_tasks_list = []
+        while True:
+            add_tasks_list = list(k for k in results_agg.keys())
+            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
+            if len(left_tasks_list) == 0:
+                break
+
+            _task_hierarchy = {
+                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
+            }
+            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)
+
+            results_agg = {**results_agg, **_results_agg}
+            groups_agg = {**groups_agg, **_groups_agg}

         for group_name, task_list in task_hierarchy.items():
             if task_list != []:
......
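Note: the rewritten print_tasks above folds the old task_order/tab bookkeeping into a recursive walk that prepends an indented "- " prefix to each subtask's alias. A toy illustration of that rendering idea (assumed structures, not the harness API):

def render(hierarchy, name, tab=0):
    # groups get no prefix; nested tasks get an indented "- " bullet
    prefix = " " * tab + "- " if tab > 0 else ""
    lines = [prefix + name]
    for child in sorted(hierarchy.get(name, [])):
        lines.extend(render(hierarchy, child, tab + 1))
    return lines

task_hierarchy = {
    "mmlu": ["mmlu_stem", "mmlu_humanities"],
    "mmlu_stem": [],
    "mmlu_humanities": [],
}
print("\n".join(render(task_hierarchy, "mmlu")))
# mmlu
#  - mmlu_humanities
#  - mmlu_stem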
...@@ -32,7 +32,7 @@ def build_filter_ensemble(filter_name, components): ...@@ -32,7 +32,7 @@ def build_filter_ensemble(filter_name, components):
Create a filtering pipeline. Create a filtering pipeline.
""" """
filters = [] filters = []
-    for (function, kwargs) in components:
+    for function, kwargs in components:
if kwargs is None: if kwargs is None:
f = get_filter(function)() f = get_filter(function)()
else: else:
......
...@@ -5,5 +5,6 @@ from . import dummy ...@@ -5,5 +5,6 @@ from . import dummy
from . import anthropic_llms from . import anthropic_llms
from . import gguf from . import gguf
from . import vllm_causallms from . import vllm_causallms
from . import mamba_lm
# TODO: implement __all__ # TODO: implement __all__
-from lm_eval.api.model import LM
-from lm_eval.api.registry import register_model
+from typing import Any, List, Tuple
 from tqdm import tqdm
-import time
 from lm_eval import utils
-from typing import List, Any, Tuple
+from lm_eval.api.model import LM
+from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
...@@ -45,9 +48,17 @@ def anthropic_completion( ...@@ -45,9 +48,17 @@ def anthropic_completion(
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`", please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
) )
-    backoff_time: float = 3
-    while True:
-        try:
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        eval_logger.warning(
+            f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
+        )
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[anthropic.RateLimitError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
         response = client.completions.create(
             prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
             model=model,
...@@ -59,12 +70,8 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
             **kwargs,
         )
         return response.completion
-        except anthropic.RateLimitError as e:
-            eval_logger.warning(
-                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
-            )
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+
+    return completion()
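Note: this commit replaces the manual backoff loops with the retry_on_specific_exceptions decorator from lm_eval.utils. Its implementation is not shown in this diff; a plausible sketch matching the keyword arguments used at the call sites (backoff_time and backoff_multiplier are assumed defaults, not confirmed by the diff):

import time
from functools import wraps
from typing import Callable, List, Optional, Type

def retry_with_backoff(
    on_exceptions: List[Type[Exception]],
    max_retries: Optional[int] = None,
    backoff_time: float = 3.0,
    backoff_multiplier: float = 1.5,
    on_exception_callback: Optional[Callable[[Exception, float], None]] = None,
):
    """Retry the wrapped callable on the given exceptions with exponential backoff."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sleep_time = backoff_time
            attempt = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except tuple(on_exceptions) as e:
                    attempt += 1
                    if max_retries is not None and attempt > max_retries:
                        raise
                    if on_exception_callback is not None:
                        on_exception_callback(e, sleep_time)
                    time.sleep(sleep_time)
                    sleep_time *= backoff_multiplier
        return wrapper
    return decorator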
@register_model("anthropic") @register_model("anthropic")
...@@ -141,6 +148,14 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -141,6 +148,14 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def generate_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
try:
import anthropic
except ModuleNotFoundError:
raise Exception(
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e .[anthropic]`",
)
if not requests: if not requests:
return [] return []
......
import random import random
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
......
import requests
import logging import logging
import time import time
from tqdm import tqdm
import requests
from requests.exceptions import RequestException from requests.exceptions import RequestException
from tqdm import tqdm
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
+import copy
 import os
-from packaging import version
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, Union
 import torch
+import torch.nn.functional as F
 import transformers
+from accelerate import Accelerator, DistributedType, find_executable_batch_size
+from packaging import version
+from peft import PeftModel
+from peft import __version__ as PEFT_VERSION
+from tqdm import tqdm
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
 )
-from peft import __version__ as PEFT_VERSION, PeftModel
-import copy
-from collections import defaultdict
-from tqdm import tqdm
-from pathlib import Path
-import torch.nn.functional as F
 from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
-from accelerate import Accelerator, find_executable_batch_size, DistributedType
-from typing import List, Optional, Union, Tuple, Literal
+from lm_eval.utils import Collator, stop_sequences_criteria
eval_logger = utils.eval_logger eval_logger = utils.eval_logger
...@@ -107,9 +105,7 @@ class HFLM(LM): ...@@ -107,9 +105,7 @@ class HFLM(LM):
eval_logger.warning( eval_logger.warning(
"`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
) )
-            assert (
-                not parallelize
-            ), "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
self._model = pretrained self._model = pretrained
self._device = self._model.device self._device = self._model.device
...@@ -137,6 +133,8 @@ class HFLM(LM): ...@@ -137,6 +133,8 @@ class HFLM(LM):
gpus = torch.cuda.device_count() gpus = torch.cuda.device_count()
accelerator = Accelerator() accelerator = Accelerator()
if accelerator.num_processes > 1:
self.accelerator = accelerator
if not (parallelize or accelerator.num_processes > 1): if not (parallelize or accelerator.num_processes > 1):
# use user-passed device # use user-passed device
...@@ -145,9 +143,7 @@ class HFLM(LM): ...@@ -145,9 +143,7 @@ class HFLM(LM):
+ [f"cuda:{i}" for i in range(torch.cuda.device_count())] + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+ ["mps", "mps:0"] + ["mps", "mps:0"]
) )
-            if device:
-                if device not in device_list:
-                    device = int(device)
+            if device and device in device_list:
self._device = torch.device(device) self._device = torch.device(device)
eval_logger.info(f"Using device '{device}'") eval_logger.info(f"Using device '{device}'")
if device in ("mps", "mps:0") and version.parse( if device in ("mps", "mps:0") and version.parse(
...@@ -170,7 +166,7 @@ class HFLM(LM): ...@@ -170,7 +166,7 @@ class HFLM(LM):
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
) )
# TODO: include in warning that `load_in_8bit` etc. affect this too # TODO: include in warning that `load_in_8bit` etc. affect this too
self._device = device self._device = torch.device(device)
# TODO: update this to be less of a hack once subfolder is fixed in HF # TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
...@@ -207,16 +203,17 @@ class HFLM(LM): ...@@ -207,16 +203,17 @@ class HFLM(LM):
self.model.eval() self.model.eval()
self.model.tie_weights() self.model.tie_weights()
-        if (gpus >= 1 or self.device.type == "mps") and isinstance(pretrained, str):
-            if not (parallelize or autogptq or ("device_map" in kwargs)):
+        if isinstance(pretrained, str) and (gpus >= 1 or str(self.device) == "mps"):
+            # TODO: can remove this whole snippet except in the mps case, perhaps?
+            if not (parallelize or autogptq or hasattr(self, "accelerator")):
# place model onto device requested manually, # place model onto device requested manually,
# if not using HF Accelerate or device_map # if not using HF Accelerate or device_map
# or any other option that preloads model onto device # or any other option that preloads model onto device
try: try:
self.model.to(self.device) self.model.to(self.device)
except ValueError: except ValueError:
-                    eval_logger.info(
-                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
+                    eval_logger.debug(
+                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
) )
self._create_tokenizer( self._create_tokenizer(
...@@ -238,7 +235,7 @@ class HFLM(LM): ...@@ -238,7 +235,7 @@ class HFLM(LM):
elif self.tokenizer.eos_token: elif self.tokenizer.eos_token:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
else: else:
if "Qwen" in pretrained: if self.config.model_type == "qwen":
# Qwen's trust_remote_code tokenizer does not allow for adding special tokens # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
self.tokenizer.pad_token = "<|endoftext|>" self.tokenizer.pad_token = "<|endoftext|>"
else: else:
...@@ -279,10 +276,13 @@ class HFLM(LM): ...@@ -279,10 +276,13 @@ class HFLM(LM):
"with 'accelerate launch *script*'. " "with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices." f"Current run will proceed with {accelerator.num_processes} devices."
) )
-                assert accelerator.distributed_type in [
-                    DistributedType.FSDP,
-                    DistributedType.MULTI_GPU,
-                ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+                assert (
+                    accelerator.distributed_type
+                    in [
+                        DistributedType.FSDP,
+                        DistributedType.MULTI_GPU,
+                    ]
+                ), "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP: if accelerator.distributed_type == DistributedType.FSDP:
self._model = accelerator.prepare(self.model) self._model = accelerator.prepare(self.model)
else: else:
...@@ -417,7 +417,6 @@ class HFLM(LM): ...@@ -417,7 +417,6 @@ class HFLM(LM):
revision: str = "main", revision: str = "main",
trust_remote_code: bool = False, trust_remote_code: bool = False,
) -> None: ) -> None:
self._config = transformers.AutoConfig.from_pretrained( self._config = transformers.AutoConfig.from_pretrained(
pretrained, pretrained,
revision=revision, revision=revision,
...@@ -460,12 +459,24 @@ class HFLM(LM): ...@@ -460,12 +459,24 @@ class HFLM(LM):
if parallelize: if parallelize:
model_kwargs.update( model_kwargs.update(
_get_accelerate_args( _get_accelerate_args(
device_map_option, device_map_option, # TODO: phase out device_map_option?
max_memory_per_gpu, max_memory_per_gpu,
max_cpu_memory, max_cpu_memory,
offload_folder, offload_folder,
) )
) )
elif "device_map" not in model_kwargs:
# set a device_map to initialize model on the right GPU.
# this is needed because it seems that the default behavior
# for quantized models now seems to be device_map="auto"
# which breaks data-parallel mode.
if hasattr(self, "accelerator"):
model_kwargs.update(
{"device_map": {"": f"cuda:{self.accelerator.local_process_index}"}}
)
else:
model_kwargs.update({"device_map": {"": str(self.device)}})
if not autogptq: if not autogptq:
if model_kwargs.get("load_in_4bit", None): if model_kwargs.get("load_in_4bit", None):
assert ( assert (
...@@ -635,7 +646,7 @@ class HFLM(LM): ...@@ -635,7 +646,7 @@ class HFLM(LM):
padding_side: str = "left", padding_side: str = "left",
left_truncate_len: int = None, left_truncate_len: int = None,
truncation: bool = False, truncation: bool = False,
) -> Tuple[List[int], List[int]]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side self.tokenizer.padding_side = padding_side
...@@ -700,7 +711,7 @@ class HFLM(LM): ...@@ -700,7 +711,7 @@ class HFLM(LM):
generation_kwargs["do_sample"] = False generation_kwargs["do_sample"] = False
# build stopping criteria # build stopping criteria
stopping_criteria = stop_sequences_criteria( stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0] self.tokenizer, stop, context.shape[1], context.shape[0]
) )
return self.model.generate( return self.model.generate(
input_ids=context, input_ids=context,
...@@ -751,8 +762,9 @@ class HFLM(LM): ...@@ -751,8 +762,9 @@ class HFLM(LM):
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
# end of text as context # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
                 )
else: else:
context_enc, continuation_enc = self._encode_pair(context, continuation) context_enc, continuation_enc = self._encode_pair(context, continuation)
...@@ -844,6 +856,7 @@ class HFLM(LM): ...@@ -844,6 +856,7 @@ class HFLM(LM):
res = [] res = []
def _collate(x): def _collate(x):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages: # the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning # - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch # - to know the size of a batch when going through the list, you know the first one is always the batch
...@@ -854,26 +867,27 @@ class HFLM(LM): ...@@ -854,26 +867,27 @@ class HFLM(LM):
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
-        re_ord = utils.Reorderer(requests, _collate)
-
-        n_reordered_requests = len(re_ord.get_reordered())
+        re_ord = Collator(requests, sort_fn=_collate)

         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
+        n_reordered_requests = len(re_ord)

-        chunks = utils.chunks(
-            re_ord.get_reordered(),
-            n=self.batch_size
+        batch_size = (
+            self.batch_size
             if self.batch_size != "auto"
             else override_bs
             if override_bs is not None
-            else 0,
-            fn=self._batch_scheduler
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
             if self.batch_size == "auto"
             and n_reordered_requests > 0
             and not override_bs
-            else None,
+            else None
         )
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0))) pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
for chunk in chunks: for chunk in chunks:
inps = [] inps = []
...@@ -995,9 +1009,7 @@ class HFLM(LM): ...@@ -995,9 +1009,7 @@ class HFLM(LM):
greedy_tokens = logits.argmax(dim=-1) greedy_tokens = logits.argmax(dim=-1)
cont_toks = torch.tensor( cont_toks = torch.tensor(
cont_toks, dtype=torch.long, device=self.device cont_toks, dtype=torch.long, device=self.device
-                ).unsqueeze(
-                    0
-                )  # [1, seq]
+                ).unsqueeze(0)  # [1, seq]
max_equal = (greedy_tokens == cont_toks).all() max_equal = (greedy_tokens == cont_toks).all()
# Obtain log-probs at the corresponding continuation token indices # Obtain log-probs at the corresponding continuation token indices
...@@ -1019,10 +1031,10 @@ class HFLM(LM): ...@@ -1019,10 +1031,10 @@ class HFLM(LM):
return re_ord.get_original(res) return re_ord.get_original(res)
     def generate_until(self, requests: List[Instance]) -> List[str]:
-        res = defaultdict(list)
-        re_ords = {}
+        res = []
def _collate(x): def _collate(x):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages: # the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning # - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch # - to know the size of a batch when going through the list, you know the first one is always the batch
...@@ -1032,14 +1044,6 @@ class HFLM(LM): ...@@ -1032,14 +1044,6 @@ class HFLM(LM):
toks = self.tok_encode(x[0]) toks = self.tok_encode(x[0])
return -len(toks), x[0] return -len(toks), x[0]
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
if self.batch_size == "auto": if self.batch_size == "auto":
# using rolling window with maximum context # using rolling window with maximum context
...@@ -1048,18 +1052,24 @@ class HFLM(LM): ...@@ -1048,18 +1052,24 @@ class HFLM(LM):
print(f"Determined Largest batch size: {batch_size}") print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size adaptive_batch_size = batch_size
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
-        for key, re_ord in re_ords.items():
-            chunks = utils.chunks(
-                re_ord.get_reordered(),
-                n=self.batch_size
-                if self.batch_size != "auto"
-                else adaptive_batch_size
-                if adaptive_batch_size is not None
-                else 0,
-                fn=self._batch_scheduler
-                if self.batch_size == "auto" and not adaptive_batch_size
-                else None,
-            )
+        batch_size = (
+            self.batch_size
+            if self.batch_size != "auto"
+            else adaptive_batch_size
+            if adaptive_batch_size is not None
+            else 0
+        )
+        batch_fn = (
+            self._batch_scheduler
+            if self.batch_size == "auto" and not adaptive_batch_size
+            else None
+        )
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
for chunk in chunks: for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk) contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same # we assume all gen kwargs in the batch are the same
...@@ -1131,15 +1141,13 @@ class HFLM(LM): ...@@ -1131,15 +1141,13 @@ class HFLM(LM):
# for seq2seq case where self.tok_decode(self.eot_token_id) = '' # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
s = s.split(term)[0] s = s.split(term)[0]
-                res[key].append(s)
-                self.cache_hook.add_partial(
-                    "generate_until", (context, gen_kwargs), s
-                )
+                res.append(s)
+                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)

-            # reorder this group of results back to original unsorted form
-            res[key] = re_ord.get_original(res[key])
+        # reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)

         pbar.close()

-        return grouper.get_original(res)
+        return res
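Note: throughout this file, utils.Reorderer plus utils.chunks is replaced by a single Collator that sorts requests, optionally groups them by generation kwargs, yields batches, and restores the original request order afterwards. A stripped-down sketch of that sort/batch/restore contract, inferred from the call sites above (the real lm_eval.utils.Collator also supports grouping=True and a batch_fn):

from typing import Any, Callable, Iterator, List

class SimpleCollator:
    def __init__(self, arr: List[Any], sort_fn: Callable[[Any], Any]):
        # remember each item's original position before sorting
        self._indexed = sorted(enumerate(arr), key=lambda pair: sort_fn(pair[1]))
        self._order = [i for i, _ in self._indexed]

    def __len__(self) -> int:
        return len(self._indexed)

    def get_batched(self, n: int) -> Iterator[List[Any]]:
        items = [x for _, x in self._indexed]
        n = n if n > 0 else len(items)
        for start in range(0, len(items), n):
            yield items[start : start + n]

    def get_original(self, results: List[Any]) -> List[Any]:
        # results arrive in sorted order; put them back in request order
        restored = [None] * len(results)
        for original_index, res in zip(self._order, results):
            restored[original_index] = res
        return restored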
from typing import Optional, Union
import torch
from lm_eval import utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
@register_model("mamba_ssm")
class MambaLMWrapper(HFLM):
def __init__(
self,
pretrained="state-spaces/mamba-130m",
**kwargs,
) -> None:
"""
Mamba (via the `mamba_ssm` package) supports the following args:
```
d_model: int,
n_layer: int,
vocab_size: int,
initializer_cfg=None,
pad_vocab_size_multiple: int = 1,
ssm_cfg=None,
norm_epsilon: float = 1e-5,
rms_norm: bool = False,
initializer_cfg=None,
fused_add_norm=False,
residual_in_fp32=False,
```
See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info.
The above can all be passed via `--model_args` or to this __init__() directly
but we recommend placing many of these within the config.json file uploaded alongside your
Mamba model to the HF Hub instead.
All other HuggingFace from_pretrained() kwargs
such as those related to
`parallelize=True`, PEFT, autoGPTQ,
or any sub-configurations of these advanced args,
are unsupported by the `mamba_ssm` package.
The HFLM arguments
`backend`, `revision`, `subfolder`, `tokenizer`, `truncation`, `max_length`,
`device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`
Are all supported by Mamba where they do not conflict
with Mamba-specific restrictions such as causal LMs only.
"""
if "backend" in kwargs:
# mamba currently only supports causal models
assert kwargs["backend"] == "causal"
super().__init__(
pretrained=pretrained,
# set appropriate defaults for tokenizer, max length, etc
backend=kwargs.get("backend", "causal"),
tokenizer=kwargs.get("tokenizer", "EleutherAI/gpt-neox-20b"),
max_length=kwargs.get("max_length", 2048),
**kwargs,
)
def _get_config(
self,
pretrained: str,
**kwargs,
) -> None:
try:
from mamba_ssm.utils.hf import load_config_hf # noqa: F811
except ModuleNotFoundError:
raise Exception(
"attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
)
self._config = load_config_hf(pretrained)
def _create_model(
self,
pretrained: str,
dtype: Optional[Union[str, torch.dtype]] = "float16",
# no `parallelize=True` options
# no PEFT and quantization options
# Mamba does not support arbitrary HF from_pretrained() args
**kwargs,
) -> None:
try:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel # noqa: F811
except ModuleNotFoundError:
raise Exception(
"attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
)
self._model = MambaLMHeadModel.from_pretrained(
pretrained,
device=self._device,
dtype=torch.float16 if dtype == "auto" else utils.get_dtype(dtype),
**kwargs,
)
def _model_generate(self, context, max_length, stop, **generation_kwargs):
for key in ("do_sample", "attention_mask"):
if key in generation_kwargs:
generation_kwargs.pop(key)
# mamba's custom GenerationMixin currently does not support
# passing stopping criteria.
# for the time being, we simply generate to max length,
# then truncate (equivalent result)
# -- this should be revisited to speed up generation
# stopping_criteria = stop_sequences_criteria(
# self.tokenizer, stop, 1, context.shape[0]
# )
return self.model.generate(
input_ids=context,
max_length=max_length,
# stopping_criteria=stopping_criteria,
# pad_token_id=self.tokenizer.pad_token_id,
# use_cache=True,
**generation_kwargs,
)
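Note: with the new mamba_lm.py registered as "mamba_ssm", the model should be usable like any other harness backend. A hypothetical usage sketch through the Python API (the task name and checkpoint are illustrative, and the call assumes the simple_evaluate signature shown earlier in this diff):

import lm_eval

results = lm_eval.simple_evaluate(
    model="mamba_ssm",
    model_args="pretrained=state-spaces/mamba-130m",
    tasks=["lambada_openai"],
    batch_size=8,
    device="cuda:0",
)
print(results["results"])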
-import os
-import time
-from typing import List, Tuple, Optional
 import copy
+import os
 from collections import defaultdict
+from importlib.util import find_spec
+from typing import List, Optional, Tuple
 from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions
def get_result(response, ctxlen: int) -> Tuple[float, bool]: def get_result(response, ctxlen: int) -> Tuple[float, bool]:
...@@ -44,24 +45,28 @@ def oa_completion(**kwargs): ...@@ -44,24 +45,28 @@ def oa_completion(**kwargs):
Retry with back-off until they respond Retry with back-off until they respond
""" """
-    try:
-        import openai, tiktoken  # noqa: E401
-    except ModuleNotFoundError:
+    if not find_spec("openai") or not find_spec("tiktoken"):
         raise Exception(
-            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
-please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
+            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
         )
+    else:
+        import openai

-    backoff_time = 3
-    while True:
-        try:
-            return openai.completions.create(**kwargs)
-        except openai.OpenAIError:
-            import traceback
-
-            traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        import traceback
+
+        traceback.print_exc()
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[openai.OpenAIError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return openai.completions.create(**kwargs)
+
+    return completion()
@register_model("openai-completions") @register_model("openai-completions")
...@@ -71,7 +76,7 @@ class OpenaiCompletionsLM(LM): ...@@ -71,7 +76,7 @@ class OpenaiCompletionsLM(LM):
def __init__( def __init__(
self, self,
model: str = "text-davinci-003", model: str,
truncate: bool = False, truncate: bool = False,
max_gen_toks: int = 256, max_gen_toks: int = 256,
batch_size: int = 1, batch_size: int = 1,
...@@ -81,14 +86,15 @@ class OpenaiCompletionsLM(LM): ...@@ -81,14 +86,15 @@ class OpenaiCompletionsLM(LM):
""" """
:param engine: str :param engine: str
-            OpenAI API engine (e.g. davinci)
+            OpenAI API engine (e.g. gpt-3.5-turbo-instruct)
:param truncate: bool :param truncate: bool
Truncate input if too long (if False and input is too long, throw error) Truncate input if too long (if False and input is too long, throw error)
""" """
super().__init__() super().__init__()
self.seed = seed self.seed = seed
try: try:
-            import openai, tiktoken  # noqa: E401
+            import openai  # noqa: E401
+            import tiktoken
except ModuleNotFoundError: except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
...@@ -102,7 +108,7 @@ class OpenaiCompletionsLM(LM): ...@@ -102,7 +108,7 @@ class OpenaiCompletionsLM(LM):
self._max_gen_toks = max_gen_toks self._max_gen_toks = max_gen_toks
self._max_length = max_length self._max_length = max_length
-        # Read from environment variable OPENAI_API_SECRET_KEY
+        # Read from environment variable OPENAI_API_KEY
openai.api_key = os.environ["OPENAI_API_KEY"] openai.api_key = os.environ["OPENAI_API_KEY"]
@property @property
...@@ -154,8 +160,9 @@ class OpenaiCompletionsLM(LM): ...@@ -154,8 +160,9 @@ class OpenaiCompletionsLM(LM):
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
# end of text as context # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
) )
else: else:
context_enc, continuation_enc = self._encode_pair(context, continuation) context_enc, continuation_enc = self._encode_pair(context, continuation)
...@@ -247,6 +254,7 @@ class OpenaiCompletionsLM(LM): ...@@ -247,6 +254,7 @@ class OpenaiCompletionsLM(LM):
list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE)) list(sameuntil_chunks(re_ord.get_reordered(), self.REQ_CHUNK_SIZE))
): ):
inps = [] inps = []
self._max_gen_toks = request_args.pop("max_gen_toks", self.max_gen_toks)
for context, _ in chunk: for context, _ in chunk:
context_enc = self.tok_encode(context) context_enc = self.tok_encode(context)
inp = context_enc[-(self.max_length - self.max_gen_toks) :] inp = context_enc[-(self.max_length - self.max_gen_toks) :]
...@@ -326,69 +334,69 @@ def oa_chat_completion(client, **kwargs): ...@@ -326,69 +334,69 @@ def oa_chat_completion(client, **kwargs):
Retry with back-off until they respond Retry with back-off until they respond
""" """
-    try:
-        import openai, tiktoken  # noqa: E401
-    except ModuleNotFoundError:
+    if not find_spec("openai") or not find_spec("tiktoken"):
         raise Exception(
-            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
-please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
+            "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. "
+            "Please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`"
         )
+    else:
+        import openai

-    async def _get_completions(**kwargs):
-        chat_completions = await client.chat.completions.create(**kwargs)
-        return chat_completions
-
-    backoff_time = 3
-    while True:
-        try:
-            return client.chat.completions.create(**kwargs)
-        except openai.OpenAIError:
-            import traceback
-
-            traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        import traceback
+
+        traceback.print_exc()
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[openai.OpenAIError],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return client.chat.completions.create(**kwargs)
+
+    return completion()
@register_model("openai-chat-completions") @register_model("openai-chat-completions", "local-chat-completions")
class OpenaiChatCompletionsLM(LM): class OpenaiChatCompletionsLM(LM):
def __init__( def __init__(
self, model: str = "gpt-3.5-turbo", truncate: bool = False, batch_size: int = 1 self,
model: str = "gpt-3.5-turbo", # GPT model or Local model using HuggingFace model paths
base_url: str = None,
truncate: bool = False,
**kwargs,
) -> None: ) -> None:
""" """
:param model: str :param model: str
Implements an OpenAI-style chat completion API for
accessing both OpenAI OR locally-hosted models using
HuggingFace Tokenizer
OpenAI API model (e.g. gpt-3.5-turbo) OpenAI API model (e.g. gpt-3.5-turbo)
using the **gen_kwargs passed on init
:param truncate: bool :param truncate: bool
Truncate input if too long (if False and input is too long, throw error) Truncate input if too long (if False and input is too long, throw error)
""" """
super().__init__() super().__init__()
try: try:
import openai, tiktoken # noqa: E401 import openai # noqa: E401
except ModuleNotFoundError: except ModuleNotFoundError:
raise Exception( raise Exception(
"attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \ "attempted to use 'openai' LM type, but package `openai` or `tiktoken` are not installed. \
please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`", please install these via `pip install lm-eval[openai]` or `pip install -e .[openai]`",
) )
self.model = model self.model = model
self.frequency_penalty = 0 self.base_url = base_url
self.logit_bias = None
self.n = 1
self.presence_penalty = 0
self.temperature = 1
self.top_p = 1
self.tokenizer = tiktoken.encoding_for_model(self.model)
self.vocab_size = self.tokenizer.n_vocab
self.truncate = truncate self.truncate = truncate
self.end_of_text_token_id = self.tokenizer.eot_token
# Read from environment variable OPENAI_API_KEY # Read from environment variable OPENAI_API_KEY
# Set to EMPTY for local
if self.base_url:
self.client = openai.OpenAI(base_url=self.base_url)
else:
self.client = openai.OpenAI() # openai.AsyncOpenAI() self.client = openai.OpenAI() # openai.AsyncOpenAI()
@property
def eot_token_id(self):
return self.end_of_text_token_id
@property @property
def max_length(self) -> int: def max_length(self) -> int:
# Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token # Note: the OpenAI API supports up to 2049 tokens, with the first token being the first input token
...@@ -408,53 +416,19 @@ class OpenaiChatCompletionsLM(LM): ...@@ -408,53 +416,19 @@ class OpenaiChatCompletionsLM(LM):
# Isn't used because we override _loglikelihood_tokens # Isn't used because we override _loglikelihood_tokens
raise NotImplementedError() raise NotImplementedError()
def tok_encode(self, string: str) -> List[int]:
return self.tokenizer.encode(string)
def tok_decode(self, tokens: List[int]) -> str:
return self.tokenizer.decode(tokens)
def _encode_pair(
self, context: str, continuation: str
) -> Tuple[List[int], List[int]]:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
continuation = context[-n_spaces:] + continuation
context = context[:-n_spaces]
whole_enc = self.tok_encode(context + continuation)
context_enc = self.tok_encode(context)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
return context_enc, continuation_enc
def generate_until(self, requests) -> List[str]: def generate_until(self, requests) -> List[str]:
res = defaultdict(list) res = defaultdict(list)
re_ords = {} re_ords = {}
def _collate(x):
toks = self.tok_encode(x[0])
return -len(toks), x[0]
# we group requests by their generation_kwargs, # we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch. # in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x.args[1])) grouper = utils.Grouper(requests, lambda x: str(x.args[1]))
for key, reqs in grouper.get_grouped().items(): for key, reqs in grouper.get_grouped().items():
# within each set of reqs for given kwargs, we reorder by token length, descending. # within each set of reqs for given kwargs, we reorder by token length, descending.
-            re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
+            re_ords[key] = utils.Reorderer(
+                [req.args for req in reqs], lambda x: (-len(x[0]), x[0])
+            )

-        def sameuntil_chunks(xs, size):
ret = []
lastuntil = xs[0][1]
for x in xs:
if len(ret) >= size or x[1] != lastuntil:
yield ret, lastuntil
ret = []
lastuntil = x[1]
ret.append(x)
if ret:
yield ret, lastuntil
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
for key, re_ord in re_ords.items(): for key, re_ord in re_ords.items():
...@@ -468,37 +442,26 @@ class OpenaiChatCompletionsLM(LM): ...@@ -468,37 +442,26 @@ class OpenaiChatCompletionsLM(LM):
gen_kwargs = all_gen_kwargs[0] gen_kwargs = all_gen_kwargs[0]
until = None until = None
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+            if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict):
+                if "do_sample" in kwargs.keys():
+                    kwargs.pop("do_sample")
                 if "until" in kwargs.keys():
                     until = kwargs.pop("until")
                     if isinstance(until, str):
                         until = [kwargs]
                     elif not isinstance(until, list):
                         raise ValueError(
-                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                            f"Expected repr(kwargs['until']) to be of type Union[str, list] but got {until}"
                         )
+                kwargs["stop"] = until
+                kwargs["max_tokens"] = kwargs.pop("max_gen_toks", self.max_gen_toks)
             else:
                 raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {kwargs}"
+                    f"Expected repr(kwargs) to be of type repr(dict) but got {kwargs}"
                 )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks

             response = oa_chat_completion(
-                client=self.client,
-                messages=inps,
-                model=self.model,
-                frequency_penalty=self.frequency_penalty,
-                # logit_bias=self.logit_bias,
-                max_tokens=max_gen_toks,
-                n=self.n,
-                presence_penalty=self.presence_penalty,
-                temperature=self.temperature,
-                top_p=self.top_p,
+                client=self.client, messages=inps, model=self.model, **kwargs
             )
for resp, (context, args_) in zip(response.choices, chunk): for resp, (context, args_) in zip(response.choices, chunk):
......
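Note: the new "local-chat-completions" registration plus base_url make it possible to point this class at any OpenAI-compatible server. A hypothetical usage sketch (model name, URL, and the dummy API key are placeholders, and the call assumes the simple_evaluate signature shown earlier in this diff):

import os
import lm_eval

os.environ.setdefault("OPENAI_API_KEY", "EMPTY")  # local servers usually ignore the key
results = lm_eval.simple_evaluate(
    model="local-chat-completions",
    model_args="model=mistralai/Mistral-7B-Instruct-v0.2,base_url=http://localhost:8000/v1",
    tasks=["gsm8k"],
)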
...@@ -13,11 +13,13 @@ Homepage: https://textsynth.com/index.html ...@@ -13,11 +13,13 @@ Homepage: https://textsynth.com/index.html
""" """
import logging import logging
import os import os
 import requests as _requests
-import time
 from tqdm import tqdm

 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import retry_on_specific_exceptions
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -27,21 +29,26 @@ def textsynth_completion(**kwargs): ...@@ -27,21 +29,26 @@ def textsynth_completion(**kwargs):
"""Query TextSynth API for completion. """Query TextSynth API for completion.
Retry with back-off until they respond. Retry with back-off until they respond.
""" """
-    backoff_time = 3
-    while True:
-        try:
-            return _requests.post(**kwargs)
-        except _requests.exceptions.RequestException:
-            import traceback
-
-            traceback.print_exc()
-            time.sleep(backoff_time)
-            backoff_time *= 1.5
+    def _exception_callback(e: Exception, sleep_time: float) -> None:
+        import traceback
+
+        traceback.print_exc()
+
+    @retry_on_specific_exceptions(
+        on_exceptions=[_requests.exceptions.RequestException],
+        max_retries=None,  # retry forever, consider changing
+        on_exception_callback=_exception_callback,
+    )
+    def completion():
+        return _requests.post(**kwargs)
+
+    return completion()
@register_model("textsynth") @register_model("textsynth")
class TextSynthLM(LM): class TextSynthLM(LM):
-    def __init__(self, engine, truncate: bool = False) -> None:
+    def __init__(self, engine, truncate: bool = False, **kwargs) -> None:
""" """
:param engine: str :param engine: str
TextSynth API engine (e.g. `gptj_6B`) TextSynth API engine (e.g. `gptj_6B`)
...@@ -149,7 +156,7 @@ class TextSynthLM(LM): ...@@ -149,7 +156,7 @@ class TextSynthLM(LM):
self.cache_hook.add_partial("generate_until", (inp, request_args), s) self.cache_hook.add_partial("generate_until", (inp, request_args), s)
else: else:
logger.error( logger.error(
f"The following response does not contain generated `text`. " "The following response does not contain generated `text`. "
"Got:\n{resp}" "Got:\n{resp}"
) )
assert False assert False
......
-from collections import defaultdict
-from typing import List, Tuple, Optional, Literal, Union, Any
-from transformers import AutoTokenizer
-from lm_eval.api.instance import Instance
-from lm_eval.api.model import LM
 import copy
+from importlib.util import find_spec
+from typing import List, Literal, Optional, Tuple, Union

 from tqdm import tqdm

+from lm_eval.api.instance import Instance
+from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
-from lm_eval import utils
+from lm_eval.utils import (
+    Collator,
+    divide,
+    eval_logger,
+    get_rolling_token_windows,
+    make_disjoint_window,
+)

 try:
-    from vllm import LLM, SamplingParams
+    import ray
     from ray.util.multiprocessing import Pool
+    from vllm import LLM, SamplingParams
     from vllm.transformers_utils.tokenizer import get_tokenizer
 except ModuleNotFoundError:
     pass

-eval_logger = utils.eval_logger
+eval_logger = eval_logger


 # adapted from https://github.com/vllm-project/vllm/issues/367#issuecomment-1788341727
-def run_inference_one_model(model_args: dict, sampling_params, requests: List[int]):
-    # gpu_id = [x for x in gpu_id]
-    # os.environ["CUDA_VISIBLE_DEVICES"]= str(gpu_id)
+def run_inference_one_model(
+    model_args: dict, sampling_params, requests: List[List[int]]
+):
     llm = LLM(**model_args)
     return llm.generate(prompt_token_ids=requests, sampling_params=sampling_params)
...@@ -40,7 +49,7 @@ class VLLM(LM): ...@@ -40,7 +49,7 @@ class VLLM(LM):
tokenizer_mode: Literal["auto", "slow"] = "auto", tokenizer_mode: Literal["auto", "slow"] = "auto",
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
quantization: Optional[Literal["awq"]] = None, quantization: Optional[str] = None,
max_gen_toks: int = 256, max_gen_toks: int = 256,
swap_space: int = 4, swap_space: int = 4,
batch_size: Union[str, int] = 1, batch_size: Union[str, int] = 1,
...@@ -54,12 +63,10 @@ class VLLM(LM): ...@@ -54,12 +63,10 @@ class VLLM(LM):
): ):
super().__init__() super().__init__()
-        try:
-            import vllm
-        except ModuleNotFoundError:
+        if not find_spec("vllm"):
             raise Exception(
-                "attempted to use 'vllm' LM type, but package `vllm` is not installed. \
-please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`",
+                "attempted to use 'vllm' LM type, but package `vllm` is not installed. "
+                "Please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             )
assert "cuda" in device or device is None, "vLLM only supports CUDA" assert "cuda" in device or device is None, "vLLM only supports CUDA"
...@@ -85,17 +92,30 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -85,17 +92,30 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
"quantization": quantization, "quantization": quantization,
"seed": int(seed), "seed": int(seed),
} }
self.batch_size = (
"auto"
if isinstance(batch_size, str) and "auto" in batch_size
else batch_size
)
if self.data_parallel_size <= 1: if self.data_parallel_size <= 1:
self.model = LLM(**self.model_args) self.model = LLM(**self.model_args)
else: else:
self.model_args["worker_use_ray"] = True self.model_args["worker_use_ray"] = True
self.batch_size = "auto"
eval_logger.info("Manual batching is not compatible with data parallelism.")
from transformers import AutoConfig
self._config = AutoConfig.from_pretrained(
pretrained, trust_remote_code=trust_remote_code, revision=revision
)
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
tokenizer if tokenizer else pretrained, tokenizer if tokenizer else pretrained,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision, tokenizer_revision=tokenizer_revision,
) )
self.batch_size = batch_size
self._max_gen_toks = max_gen_toks self._max_gen_toks = max_gen_toks
@property @property
...@@ -107,7 +127,16 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -107,7 +127,16 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
def max_length(self): def max_length(self):
if self._max_length: # if max length manually set, return it if self._max_length: # if max length manually set, return it
return self._max_length return self._max_length
if self.data_parallel_size <= 1:
return self.model.llm_engine.model_config.max_model_len
else:
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
for attr in seqlen_config_attrs:
if hasattr(self._config, attr):
return getattr(self._config, attr)
if hasattr(self.tokenizer, "model_max_length"): if hasattr(self.tokenizer, "model_max_length"):
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
return self._DEFAULT_MAX_LENGTH
return self.tokenizer.model_max_length return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH return self._DEFAULT_MAX_LENGTH
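Note: the new max_length property resolves the context window from several sources in order. A condensed, illustrative sketch of that precedence (names are stand-ins, not the class's attributes):

def resolve_max_length(explicit_max_length, engine_max_len, hf_config, tokenizer, default=2048):
    # 1. explicit user override wins
    if explicit_max_length:
        return explicit_max_length
    # 2. single-engine path: ask the vLLM engine config directly
    if engine_max_len:
        return engine_max_len
    # 3. data-parallel path: fall back to HF config attributes
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(hf_config, attr):
            return getattr(hf_config, attr)
    # 4. tokenizer limit, ignoring the "effectively unlimited" sentinel
    model_max = getattr(tokenizer, "model_max_length", None)
    if model_max and model_max != 1000000000000000019884624838656:
        return model_max
    return default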
...@@ -155,13 +184,13 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -155,13 +184,13 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
temperature=0, prompt_logprobs=2, max_tokens=1 temperature=0, prompt_logprobs=2, max_tokens=1
) )
if self.data_parallel_size > 1: if self.data_parallel_size > 1:
-            requests = [
-                list(x) for x in utils.divide(requests, self.data_parallel_size)
-            ]
+            requests = [list(x) for x in divide(requests, self.data_parallel_size)]
inputs = [(self.model_args, sampling_params, req) for req in requests] inputs = [(self.model_args, sampling_params, req) for req in requests]
with Pool(self.data_parallel_size) as pool: with Pool(self.data_parallel_size) as pool:
results = pool.starmap(run_inference_one_model, inputs) results = pool.starmap(run_inference_one_model, inputs)
# Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
ray.shutdown()
# flatten results # flatten results
return [item for sublist in results for item in sublist] return [item for sublist in results for item in sublist]
...@@ -170,7 +199,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -170,7 +199,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False, use_tqdm=True if self.batch_size == "auto" else False,
) )
return outputs return outputs
def _encode_pair( def _encode_pair(
...@@ -193,8 +221,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -193,8 +221,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
for context, continuation in [req.args for req in requests]: for context, continuation in [req.args for req in requests]:
if context == "": if context == "":
# end of text as context # end of text as context
-                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
-                    continuation
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
) )
else: else:
context_enc, continuation_enc = self._encode_pair(context, continuation) context_enc, continuation_enc = self._encode_pair(context, continuation)
...@@ -209,8 +238,8 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -209,8 +238,8 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
for (string,) in tqdm([req.args for req in requests]): for (string,) in tqdm([req.args for req in requests]):
rolling_token_windows = list( rolling_token_windows = list(
map( map(
-                    utils.make_disjoint_window,
-                    utils.get_rolling_token_windows(
+                    make_disjoint_window,
+                    get_rolling_token_windows(
token_list=self.tok_encode(string), token_list=self.tok_encode(string),
prefix_token=self.eot_token_id, prefix_token=self.eot_token_id,
max_seq_len=self.max_length - 1, max_seq_len=self.max_length - 1,
...@@ -233,8 +262,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -233,8 +262,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
return loglikelihoods return loglikelihoods
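`loglikelihood_rolling` scores an entire string by cutting its tokens into windows whose predicted spans tile the sequence without overlap, each conditioned on the tokens before it (with the EOT token in front of the very first span). The toy below is not the `get_rolling_token_windows` / `make_disjoint_window` implementation, only a demonstration of that invariant:

```python
def toy_disjoint_windows(tokens, prefix_token, max_seq_len):
    # Each window predicts a disjoint span of at most max_seq_len tokens,
    # conditioned on (up to max_seq_len of) the tokens preceding it.
    out = []
    for start in range(0, len(tokens), max_seq_len):
        pred = tokens[start : start + max_seq_len]
        ctx = ([prefix_token] + tokens[:start])[-max_seq_len:]
        out.append((ctx, pred))
    return out

windows = toy_disjoint_windows(list(range(10)), prefix_token=-1, max_seq_len=4)
# every token of the string is scored exactly once
assert [t for _, pred in windows for t in pred] == list(range(10))
```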
def generate_until(self, requests: List[Instance]) -> List[str]: def generate_until(self, requests: List[Instance]) -> List[str]:
res = defaultdict(list) res = []
re_ords = {}
# batch tokenize contexts # batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests)) context, all_gen_kwargs = zip(*(req.args for req in requests))
...@@ -250,24 +278,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -250,24 +278,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
# padded context length. this is useful to simplify the batching logic and more importantly to make # padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement # automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end # - any OOMs will happen right away rather than near the end
return -len(_requests[0][1]), tuple(_requests[0][1]) return -len(_requests[0][1]), _requests[0][0]
# we group requests by their generation_kwargs, # we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
# in the same batch. # in the same batch.
grouper = utils.Grouper(requests, lambda x: str(x[1])) re_ords = Collator(requests, _collate_gen, grouping=True)
for key, reqs in grouper.get_grouped().items(): chunks = re_ords.get_batched(
# within each set of reqs for given kwargs, we reorder by token length, descending. n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
re_ords[key] = utils.Reorderer(requests, _collate_gen) )
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(self.rank != 0))
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
chunks = utils.chunks(
re_ord.get_reordered(),
n=int(self.batch_size) if self.batch_size != "auto" else 0,
fn=None,
)
for chunk in chunks: for chunk in chunks:
context_and_encoding, all_gen_kwargs = zip(*chunk) context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding) context, context_encoding = zip(*context_and_encoding)
...@@ -302,8 +324,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -302,8 +324,6 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
max_ctx_len = self.max_length - max_gen_toks max_ctx_len = self.max_length - max_gen_toks
context_encoding = [x[-max_ctx_len:] for x in context_encoding] context_encoding = [x[-max_ctx_len:] for x in context_encoding]
# TODO: max_length in kwargs
# perform batched generation # perform batched generation
cont = self._model_generate( cont = self._model_generate(
requests=context_encoding, requests=context_encoding,
...@@ -316,18 +336,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -316,18 +336,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
# cache generations # cache generations
for output, context in zip(cont, context): for output, context in zip(cont, context):
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
res[key].append(generated_text) res.append(generated_text)
self.cache_hook.add_partial( self.cache_hook.add_partial(
"generate_until", (context, gen_kwargs), generated_text "generate_until", (context, gen_kwargs), generated_text
) )
pbar.update(1) pbar.update(1)
# reorder this group of results back to original unsorted form
res[key] = re_ord.get_original(res[key])
pbar.close() pbar.close()
# reorder all group of results back to original unsorted form
return grouper.get_original(res) return re_ords.get_original(res)
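The switch from the old `Grouper`/`Reorderer` pair to a single `Collator` keeps the same overall contract: requests are grouped by their generation kwargs (so greedy and sampled requests never share a batch), each group is sorted longest-first so OOMs surface early, batches are dispatched, and the results are finally restored to the caller's original order. A generic sketch of that contract (not the actual `Collator` API):

```python
def group_sort_restore(requests, group_key, sort_key, process_batch, batch_size):
    # Remember each request's original position.
    indexed = list(enumerate(requests))
    groups = {}
    for idx, req in indexed:
        groups.setdefault(group_key(req), []).append((idx, req))

    results = [None] * len(requests)
    for members in groups.values():
        members.sort(key=lambda pair: sort_key(pair[1]))
        for i in range(0, len(members), batch_size):
            batch = members[i : i + batch_size]
            outputs = process_batch([req for _, req in batch])
            for (idx, _), out in zip(batch, outputs):
                results[idx] = out
    # Results come back in the caller's original request order.
    return results
```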
def _loglikelihood_tokens( def _loglikelihood_tokens(
self, self,
...@@ -340,16 +357,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -340,16 +357,15 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
toks = x[1] + x[2] toks = x[1] + x[2]
return -len(toks), tuple(toks) return -len(toks), tuple(toks)
re_ord = utils.Reorderer(requests, _collate) # Reorder requests by length and batch
re_ord = Collator(requests, sort_fn=_collate)
chunks = utils.chunks( chunks = re_ord.get_batched(
re_ord.get_reordered(), n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
n=int(self.batch_size) if self.batch_size != "auto" else 0,
fn=None,
) )
pbar = tqdm(total=len(requests), disable=disable_tqdm) pbar = tqdm(total=len(requests), disable=disable_tqdm)
for chunk in chunks: for chunk in chunks:
inps = [] inputs = []
ctxlens = [] ctxlens = []
for cache_key, context_enc, continuation_enc in chunk: for cache_key, context_enc, continuation_enc in chunk:
inp = (context_enc + continuation_enc)[-(self.max_length) :] inp = (context_enc + continuation_enc)[-(self.max_length) :]
...@@ -357,18 +373,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -357,18 +373,18 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
0, len(context_enc) + len(continuation_enc) - (self.max_length) 0, len(context_enc) + len(continuation_enc) - (self.max_length)
) )
inps.append(inp) inputs.append(inp)
ctxlens.append(ctxlen) ctxlens.append(ctxlen)
outputs = self._model_generate(requests=inps, generate=False) outputs = self._model_generate(requests=inputs, generate=False)
for output, ctxlen, (cache_key, context_enc, continuation_enc) in zip( for output, ctxlen, (cache_key, _, _), inp in zip(
outputs, ctxlens, chunk outputs, ctxlens, chunk, inputs
): ):
answer = self._parse_logprobs( answer = self._parse_logprobs(
(context_enc + continuation_enc), tokens=inp,
output, outputs=output,
ctxlen, ctxlen=ctxlen,
) )
res.append(answer) res.append(answer)
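The truncation above keeps only the last `max_length` tokens of context+continuation and shrinks `ctxlen` by exactly the number of context tokens that fell off the left edge, so continuation tokens are never dropped. A small worked example of that arithmetic (illustrative values only):

```python
max_length = 8
context_enc = [0, 1, 2, 3, 4, 5]            # 6 context tokens
continuation_enc = [100, 101, 102, 103]     # 4 continuation tokens

inp = (context_enc + continuation_enc)[-max_length:]
ctxlen = len(context_enc) - max(
    0, len(context_enc) + len(continuation_enc) - max_length
)

assert len(inp) == 8
assert ctxlen == 4                          # two context tokens were clipped
assert inp[ctxlen:] == continuation_enc     # the continuation survives intact
```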
...@@ -385,9 +401,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -385,9 +401,9 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
"""Process logprobs and tokens. """Process logprobs and tokens.
:param tokens: list :param tokens: list
Tokens from context+continuations Input tokens (potentially left-truncated)
:param outputs: RequestOutput :param outputs: RequestOutput
Contains prompt Contains prompt_logprobs
:param ctxlen: int :param ctxlen: int
Length of context (so we can slice them away and only keep the predictions) Length of context (so we can slice them away and only keep the predictions)
:return: :return:
...@@ -397,11 +413,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -397,11 +413,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
Whether argmax matches given continuation exactly Whether argmax matches given continuation exactly
""" """
# prompt_logprobs = [None, {}*len(context-1)] # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
continuation_logprobs_dicts = outputs.prompt_logprobs continuation_logprobs_dicts = outputs.prompt_logprobs
# Calculate continuation_logprobs # Calculate continuation_logprobs
# assume ctxlen always > 1 # assume ctxlen always >= 1
continuation_logprobs = sum( continuation_logprobs = sum(
logprob_dict.get(token) logprob_dict.get(token)
for token, logprob_dict in zip( for token, logprob_dict in zip(
......
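Given vLLM's `prompt_logprobs` (one dict per prompt position, the first entry being `None` since nothing precedes the first token), the continuation score is the sum of the chosen tokens' logprobs from position `ctxlen` onward, and greediness means each chosen token is also the argmax of its dict. A simplified sketch, assuming each entry maps token id to a plain float (newer vLLM versions wrap the value in a `Logprob` object, so `.logprob` would be needed there):

```python
def score_continuation(tokens, prompt_logprobs, ctxlen):
    # ctxlen >= 1, so the leading None entry is always sliced away.
    cont_tokens = tokens[ctxlen:]
    cont_dicts = prompt_logprobs[ctxlen:]
    logprob = sum(d[t] for t, d in zip(cont_tokens, cont_dicts))
    is_greedy = all(t == max(d, key=d.get) for t, d in zip(cont_tokens, cont_dicts))
    return logprob, is_greedy
```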
...@@ -69,7 +69,6 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None ...@@ -69,7 +69,6 @@ def get_prompt(prompt_id: str, dataset_name: str = None, subset_name: str = None
def load_prompt_list( def load_prompt_list(
use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs use_prompt: str, dataset_name=None, subset_name=None, yaml_path=None, **kwargs
): ):
category_name, prompt_name = use_prompt.split(":") category_name, prompt_name = use_prompt.split(":")
if category_name == "promptsource": if category_name == "promptsource":
...@@ -113,7 +112,6 @@ class PromptString: ...@@ -113,7 +112,6 @@ class PromptString:
self.prompt_string = prompt_string self.prompt_string = prompt_string
def apply(self, doc): def apply(self, doc):
doc_to_text = self.prompt_string["doc_to_text"] doc_to_text = self.prompt_string["doc_to_text"]
doc_to_target = self.prompt_string["doc_to_target"] doc_to_target = self.prompt_string["doc_to_target"]
......
...@@ -61,11 +61,27 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) - ...@@ -61,11 +61,27 @@ def register_configurable_group(config: Dict[str, str], yaml_path: str = None) -
task_list = [task for task in all_task_list if type(task) == str] task_list = [task for task in all_task_list if type(task) == str]
for task_config in config_list: for task_config in config_list:
base_config = {}
task_name_config = {}
if "task" in task_config:
task_name = task_config["task"]
if task_name in ALL_TASKS:
task_obj = get_task_dict(task_name)[task_name]
if type(task_obj) == tuple:
_, task_obj = task_obj
if task_obj is not None:
base_config = task_obj._config.to_dict()
task_name_config["task"] = f"{group}_{task_name}"
task_config = utils.load_yaml_config(yaml_path, task_config) task_config = utils.load_yaml_config(yaml_path, task_config)
var_configs = check_prompt_config( var_configs = check_prompt_config(
{ {
**base_config,
**task_config, **task_config,
**{"group": group}, **{"group": group},
**task_name_config,
}, },
yaml_path=os.path.dirname(yaml_path), yaml_path=os.path.dirname(yaml_path),
) )
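The new `base_config` lookup lets a group YAML reference an already-registered task and inherit its full config; dict-unpacking order then decides precedence: the group YAML's own keys override the inherited base, the group name is forced, and the task is re-registered under `<group>_<task>`. A small illustration with made-up task and group names:

```python
# later dicts win on key collisions
base_config = {"task": "arc_easy", "num_fewshot": 0, "output_type": "multiple_choice"}
task_config = {"task": "arc_easy", "num_fewshot": 25}   # overrides from the group YAML
group = "my_group"
task_name_config = {"task": f"{group}_arc_easy"}

merged = {**base_config, **task_config, **{"group": group}, **task_name_config}
assert merged["num_fewshot"] == 25
assert merged["task"] == "my_group_arc_easy"
assert merged["group"] == "my_group"
assert merged["output_type"] == "multiple_choice"       # inherited untouched
```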
...@@ -131,7 +147,10 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None: ...@@ -131,7 +147,10 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
""" """
Walk `task_dir` and register every YAML task or group config found. Walk `task_dir` and register every YAML task or group config found.
""" """
for root, subdirs, file_list in reversed(list(os.walk(task_dir))):
# Track whether any tasks failed during loading
import_fail = False
for root, subdirs, file_list in os.walk(task_dir):
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0): # if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list: for f in file_list:
if f.endswith(".yaml"): if f.endswith(".yaml"):
...@@ -155,20 +174,27 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None: ...@@ -155,20 +174,27 @@ def include_task_folder(task_dir: str, register_task: bool = True) -> None:
# Log this silently and show it only when # Log this silently and show it only when
# the user defines the appropriate verbosity. # the user defines the appropriate verbosity.
except ModuleNotFoundError as e: except (ImportError, ModuleNotFoundError) as e:
import_fail = True
eval_logger.debug( eval_logger.debug(
f"{yaml_path}: {e}. Config will not be added to registry." f"{yaml_path}: {e}. Config will not be added to registry."
) )
except Exception as error: except Exception as error:
import traceback import traceback
eval_logger.debug( eval_logger.warning(
"Failed to load config in\n" "Unexpected error loading config in\n"
f" {yaml_path}\n" f" {yaml_path}\n"
" Config will not be added to registry\n" " Config will not be added to registry\n"
f" Error: {error}\n" f" Error: {error}\n"
f" Traceback: {traceback.format_exc()}" f" Traceback: {traceback.format_exc()}"
) )
if import_fail:
eval_logger.warning(
"Some tasks could not be loaded due to missing dependencies."
" Run with `--verbosity DEBUG` for full details."
)
return 0 return 0
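With this change, dependency failures (`ImportError`/`ModuleNotFoundError`) are logged at debug level and remembered, unexpected errors are logged as warnings with a traceback, and a single summary warning is emitted after the walk instead of one message per failing YAML. A minimal sketch of that pattern with hypothetical loader and logger names:

```python
import logging
import traceback

logger = logging.getLogger("task_loader")

def load_all(yaml_paths, load_one):
    import_fail = False
    for path in yaml_paths:
        try:
            load_one(path)
        except (ImportError, ModuleNotFoundError) as e:
            # Missing optional dependency: stay quiet, but remember it.
            import_fail = True
            logger.debug("%s: %s. Config will not be added to registry.", path, e)
        except Exception as error:
            # Anything else is unexpected and worth a loud warning.
            logger.warning(
                "Unexpected error loading config in %s: %s\n%s",
                path, error, traceback.format_exc(),
            )
    if import_fail:
        logger.warning(
            "Some tasks could not be loaded due to missing dependencies."
            " Run with `--verbosity DEBUG` for full details."
        )
```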
...@@ -180,7 +206,6 @@ def include_path(task_dir): ...@@ -180,7 +206,6 @@ def include_path(task_dir):
def initialize_tasks(verbosity="INFO"): def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}")) eval_logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/" task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
......
...@@ -23,4 +23,4 @@ metric_list: ...@@ -23,4 +23,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -20,4 +20,4 @@ metric_list: ...@@ -20,4 +20,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -13,4 +13,4 @@ metric_list: ...@@ -13,4 +13,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -11,4 +11,4 @@ metric_list: ...@@ -11,4 +11,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 1.0 version: 1.0
...@@ -17,4 +17,4 @@ metric_list: ...@@ -17,4 +17,4 @@ metric_list:
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
metadata: metadata:
- version: 0.0 version: 1.0
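The repeated `metadata` edits above drop the leading `- `, turning a one-element list into a plain mapping; with the list form, `config["metadata"]` parses to a list of dicts and a direct `["version"]` lookup fails. A quick check of the two parses:

```python
import yaml  # PyYAML

old = yaml.safe_load("metadata:\n  - version: 1.0\n")
new = yaml.safe_load("metadata:\n  version: 1.0\n")

assert old["metadata"] == [{"version": 1.0}]   # list of one dict
assert new["metadata"] == {"version": 1.0}     # plain mapping
assert new["metadata"]["version"] == 1.0
```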
...@@ -24,7 +24,6 @@ def parse_args(): ...@@ -24,7 +24,6 @@ def parse_args():
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
# get filename of base_yaml so we can `"include": ` it in our other YAMLs. # get filename of base_yaml so we can `"include": ` it in our other YAMLs.
...@@ -37,7 +36,6 @@ if __name__ == "__main__": ...@@ -37,7 +36,6 @@ if __name__ == "__main__":
dataset_path = "lukaemon/bbh" dataset_path = "lukaemon/bbh"
for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()): for task in tqdm(datasets.get_dataset_infos(dataset_path).keys()):
resp = requests.get( resp = requests.get(
f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt" f"https://raw.githubusercontent.com/suzgunmirac/BIG-Bench-Hard/main/cot-prompts/{task}.txt"
).content.decode("utf-8") ).content.decode("utf-8")
......
...@@ -27,4 +27,4 @@ filter_list: ...@@ -27,4 +27,4 @@ filter_list:
- function: "take_first" - function: "take_first"
num_fewshot: 0 num_fewshot: 0
metadata: metadata:
- version: 1.0 version: 2.0