Unverified Commit 4f0a7e57 authored by Lintang Sutawika, committed by GitHub

Merge pull request #648 from EleutherAI/edge-case-lowbits

[Refactor] Misc. bugfixes ; edgecase quantized models
parents 42c10bd6 7f557daa
@@ -14,32 +14,42 @@ Tasks are configured via the `TaskConfig` object. Below, we describe all fields
### Parameters

Task naming + registration:
- **task** (`str`, defaults to None) — name of the task.
- **group** (`str`, *optional*) — name of the task group(s) a task belongs to. Enables one to run all tasks with a specified tag or group name at once.
- **reference** (`str`, *optional*) —

Dataset configuration options:
- **dataset_path** (`str`) — The name of the dataset as listed by HF in the datasets Hub.
- **dataset_name** (`str`, *optional*, defaults to None) — The name of what HF calls a "data instance" or sub-task of the benchmark. If your task does not contain any data instances, just leave this to default to None. (If you're familiar with the HF `datasets.load_dataset` function, these are just the first two arguments to it.)
- **dataset_kwargs** (`dict`, *optional*) — Auxiliary arguments that `datasets.load_dataset` accepts. This can be used to specify arguments such as `data_files` or `data_dir` if you want to use local data files such as JSON or CSV.
- **training_split** (`str`, *optional*) — Split in the dataset to use as the training split.
- **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
- **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
- **fewshot_split** (`str`, *optional*) — Split in the dataset from which to draw few-shot exemplars. Should be set if `num_fewshot` > 0.

Prompting / in-context formatting options:
- **template_aliases** (`str`, *optional*) — A field for additional Jinja2 content. It is not intended to render as text after applying a Jinja template; instead, it defines Jinja variables that can be used within the written prompts (for example, mapping the dataset column `label` to the new name `gold`).
- **use_prompt** (`str`, *optional*) — Name of a prompt in promptsource to use. If defined, it will overwrite `doc_to_text` and `doc_to_target`, and `template_aliases` will be unused.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2 template, f-string, or function to process a sample into the appropriate input for the model.
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2 template, f-string, or function to process a sample into the appropriate target output for the model.
- **gold_alias** (`str`, *optional*, defaults to None) — If provided, used to generate the reference answer that is scored against. Useful in cases where `doc_to_target` should be the "target string" appended to each few-shot exemplar's input, so `doc_to_target` formats the few-shot examples while the `gold` value passed to the metric function comes from `gold_alias`.
- **fewshot_delimiter** (`str`, *optional*, defaults to `"\n\n"`) — String to insert between few-shot examples.
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.

Runtime configuration options:
- **num_fewshot** (`int`, *optional*, defaults to 0) — Number of few-shot examples before the input.
- **batch_size** (`int`, *optional*, defaults to 1) — Batch size.

Scoring details:
- **metric_list** (`list`, *optional*, defaults to None) — A list of metrics to use for evaluation. See the docs for the expected format, and the example configuration sketched after this list.
- **output_type** (`str`, *optional*, defaults to "greedy_until") — Selects the type of model output for the given task. Options are `greedy_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
- **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from the HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
- **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs through the model for each sample. Can be used for techniques such as self-consistency.
- **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API.
- **should_decontaminate** (`bool`, *optional*, defaults to False) —
- **doc_to_decontamination_query** (`str`, *optional*) —

Other:
- **metadata** (`str`, *optional*) — An optional field where arbitrary metadata can be passed.
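For concreteness, here is a minimal sketch of a config that exercises several of the fields above. It is illustrative only: the import path, task name, and dataset identifiers are assumptions, while the prompting and metric fields mirror the SuperGLUE-style YAML config that appears later in this diff.

```python
# A minimal sketch, not a canonical example: the import path, task name, and
# dataset identifiers below are assumptions; the prompting and metric fields
# mirror the SuperGLUE-style YAML config shown elsewhere in this changeset.
from lm_eval.api.task import TaskConfig  # import path assumed

example_config = TaskConfig(
    task="boolq",                  # task name chosen for illustration
    dataset_path="super_glue",     # assumed HF dataset path
    dataset_name="boolq",          # assumed HF config ("data instance") name
    training_split="train",
    validation_split="validation",
    # define `answer_choices` inside Jinja so the templates below can use it
    template_aliases="{% set answer_choices = ['no', 'yes'] %}",
    doc_to_text="{{passage}}\nQuestion: {{question}}\nAnswer:",
    doc_to_target="{{answer_choices[label]}}",
    gold_alias=" {{answer_choices[label]}}",
    target_delimiter=" ",
    fewshot_delimiter="\n\n",
    num_fewshot=0,
    output_type="greedy_until",
    generation_kwargs={"until": ["\n\n", "\n"], "do_sample": False},
    metric_list=[
        {"metric": "exact_match", "aggregation": "mean", "higher_is_better": True}
    ],
)
```

The same fields can equally be expressed as a YAML task config, as the `super_glue` examples further down in this changeset do.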
## Filters
...
# Description Guide
![fewshot-example](./img/fewshot_example_gpt3.png)
(Figure from [Brown et al., 2020](https://arxiv.org/pdf/2005.14165.pdf))
Task descriptions provide in-context task instruction for your language model. If you'd like to prepend a natural language description to your few-shot examples and prompt, you can do so on a per-task basis via the `description_dict` arg of [`evaluator.evaluate`](../lm_eval/evaluator.py). This `description_dict` must adhere to the following key-value structure:
- **key**: the task name (`str`) as specified in the lm-eval-harness [task registry](../lm_eval/tasks/__init__.py).
- **value**: the corresponding (`str`) description/prompt for the task identified by **key**.
```python
description_dict = {
"task_name_1": "description",
"task_name_2": "description",
...
}
```
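When driving the evaluator from Python instead of the CLI, the same dictionary can be passed through directly. The snippet below is a hedged sketch: only `description_dict` itself comes from this guide, while the model type, `model_args`, and the other keyword arguments are illustrative assumptions and may differ between harness versions.

```python
# Hedged sketch of the Python-side flow; only `description_dict` comes from
# this guide -- the model type, checkpoint, and other kwargs are assumptions.
from lm_eval import evaluator

description_dict = {
    "cycle_letters": "Please unscramble the letters into a word, and write that word:",
}

results = evaluator.simple_evaluate(
    model="hf-causal",             # model type assumed for illustration
    model_args="pretrained=gpt2",  # checkpoint assumed for illustration
    tasks=["cycle_letters"],
    num_fewshot=2,
    description_dict=description_dict,
)
```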
Note that a task's description will be separated from its following few-shot examples and prompt by a newline, as follows:
```python
"""
<description>
<examples>
<prompt>
"""
```
## Descriptions in File
One can also interface with the aforementioned [`evaluator.evaluate`](../lm_eval/evaluator.py) (or `evaluator.simple_evaluate`) method from a higher level by simply passing a JSON file path to the `description_dict_path` arg of the command-line interface (CLI) program, `main.py`. The JSON file pointed to should be structured the same as the `description_dict`. E.g. for some file at `/your/path/descriptions.json` you may have:
```json
{
"cycle_letters": "Please unscramble the letters into a word, and write that word:",
"copa": "Given a premise and one alternative with a causal relation to the premise and another without, choose the more plausible alternative"
}
```
which can then be supplied to the CLI as:
```bash
python main.py \
--tasks cycle_letters,copa \
--description_dict_path /your/path/descriptions.json \
...
```
@@ -156,3 +156,17 @@ def get_aggregation(name):
        raise Warning(
            "{} not a registered aggregation metric!".format(name),
        )


def get_default_aggregation(metric_name):
    try:
        return DEFAULT_AGGREGATION_REGISTRY[metric_name]
    except KeyError:
        raise Warning(f"No default aggregation metric for metric '{metric_name}'!")


def is_higher_better(metric_name):
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        raise Warning(f"higher_is_better not specified for metric '{metric_name}'!")
@@ -24,19 +24,18 @@ from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api.metrics import (
    mean,
    weighted_perplexity,
    bits_per_byte,
)
from lm_eval.api.registry import (
    get_metric,
    get_aggregation,
    get_default_aggregation,
    is_higher_better,
    DEFAULT_METRIC_REGISTRY,
    OUTPUT_TYPE_REGISTRY,
    AGGREGATION_REGISTRY,
)

ALL_OUTPUT_TYPES = [
@@ -49,10 +48,12 @@ ALL_OUTPUT_TYPES = [
@dataclass
class TaskConfig(dict):
    # task naming/registry
    task: str = None
    group: Union[str, list] = None
    # HF dataset options.
    # which dataset to use,
    # and what splits for what purpose
    dataset_path: str = None
    dataset_name: str = None
    dataset_kwargs: dict = None
@@ -60,23 +61,24 @@ class TaskConfig(dict):
    validation_split: str = None
    test_split: str = None
    fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
    # formatting / prompting options.
    # see docs/advanced_task_guide.md for more info
    template_aliases: str = None
    doc_to_text: Union[Callable, str] = None
    doc_to_target: Union[Callable, str] = None
    gold_alias: Union[Callable, str] = None
    use_prompt: str = None
    description: str = ""
    target_delimiter: str = " "
    fewshot_delimiter: str = "\n\n"
    # runtime configuration options
    num_fewshot: int = 0
    batch_size: int = 1
    # scoring options
    metric_list: str = None
    output_type: str = "greedy_until"
    generation_kwargs: dict = None
    repeats: int = 1
    filter_list: Union[str, list] = None
    should_decontaminate: bool = False
    doc_to_decontamination_query: str = None
@@ -514,13 +516,11 @@ class ConfigurableTask(Task):
        if self._config.metric_list is None:
            # TODO: handle this in TaskConfig.__post_init__ ?
            for metric_name in _metric_list:
                self._metric_fn_list[metric_name] = get_metric(metric_name)
                self._aggregation_list[metric_name] = get_default_aggregation(
                    metric_name
                )
                self._higher_is_better[metric_name] = is_higher_better(metric_name)
        else:
            for metric_config in self._config.metric_list:
                assert "metric" in metric_config
@@ -530,30 +530,13 @@ class ConfigurableTask(Task):
                    for key in metric_config
                    if key not in ["metric", "aggregation", "higher_is_better"]
                }
                self._metric_fn_list[metric_name] = get_metric(metric_name)
                self._metric_fn_kwargs[metric_name] = kwargs

                if "aggregation" in metric_config:
                    agg_name = metric_config["aggregation"]
                    if type(agg_name) == str:
                        self._aggregation_list[metric_name] = get_aggregation(agg_name)
                    elif callable(agg_name):
                        self._aggregation_list[metric_name] = metric_config[
                            "aggregation"
@@ -561,7 +544,7 @@ class ConfigurableTask(Task):
                else:
                    INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                    metric_agg = get_default_aggregation(metric_name)
                    eval_logger.warning(
                        f"metric {metric_name} is defined, but aggregation is not. "
                        f"using default "
@@ -577,11 +560,9 @@ class ConfigurableTask(Task):
                    eval_logger.warning(
                        f"metric {metric_name} is defined, but higher_is_better is not. "
                        f"using default "
                        f"higher_is_better={is_higher_better(metric_name)}"
                    )
                    self._higher_is_better[metric_name] = is_higher_better(metric_name)

        self.download(self._config.dataset_kwargs)
        self._training_docs = None
@@ -834,7 +815,6 @@ class ConfigurableTask(Task):
            else:
                gold = int(self.doc_to_target(doc))

            # retrieve choices in List[str] form, to compute choice lengths, etc.
            choices = ast.literal_eval(
                utils.apply_template(
@@ -852,6 +832,8 @@ class ConfigurableTask(Task):
                # and this stores our "regular" conditional loglikelihoods
                lls = lls[::2]

            pred = np.argmax(lls)
            acc = 1.0 if np.argmax(lls) == gold else 0.0
            completion_len = np.array([float(len(i)) for i in choices])
            acc_norm = 1.0 if np.argmax(lls / completion_len) == gold else 0.0
@@ -863,7 +845,6 @@ class ConfigurableTask(Task):
                **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
            }

            if "exact_match" in self._metric_fn_list.keys():
                # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
                is_greedy = is_greedy[gold]  # take value for the gold answer
@@ -884,7 +865,7 @@ class ConfigurableTask(Task):
            gold = self.doc_to_target(doc)

            for key, result in zip(self._metric_fn_list.keys(), results):
                _dict = self._metric_fn_list[key](
                    references=[gold],
                    predictions=[result],
                    **self._metric_fn_kwargs[key],
...
@@ -45,7 +45,6 @@ def simple_evaluate(
    check_integrity=False,
    decontamination_ngrams_path=None,
    write_out=False,
):
    """Instantiate and evaluate a model on a list of tasks.
@@ -74,8 +73,6 @@ def simple_evaluate(
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write details about prompts and logits to json for all tasks
    :return
        Dictionary of results
    """
@@ -121,7 +118,6 @@ def simple_evaluate(
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
    )
    if lm.rank == 0:
@@ -158,7 +154,6 @@ def evaluate(
    bootstrap_iters=100000,
    decontamination_ngrams_path=None,
    write_out=False,
):
    """Instantiate and evaluate a model on a list of tasks.
@@ -174,8 +169,6 @@ def evaluate(
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write all prompts, logits and metrics to json for offline analysis
    :return
        Dictionary of results
    """
...
@@ -115,9 +115,10 @@ class HFLM(LM):
                else torch.device("cpu")
            )
        else:
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = device
@@ -204,7 +205,12 @@ class HFLM(LM):
            self.model.tie_weights()
        if gpus <= 1 and not parallelize:
            # place model onto device, if not using HF Accelerate in any form
            try:
                self.model.to(self.device)
            except ValueError:
                eval_logger.info(
                    "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
@@ -246,7 +252,12 @@ class HFLM(LM):
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
                try:
                    self.model.to(self.device)
                except ValueError:
                    eval_logger.info(
                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
                    )
            else:
                self._model = accelerator.prepare(self.model)
                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
...
@@ -19,6 +19,6 @@ metric_list:
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
  # - metric: acc_mutual_info
  #   aggregation: mean
  #   higher_is_better: true
@@ -8,7 +8,13 @@ training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{answer_choices[label]}}"
gold_alias: " {{answer_choices[label]}}" # this will be cast to an int.
generation_kwargs:
  until:
    - "\n\n"
    - "\n"
  do_sample: false
  temperature: 0.0
template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
metric_list:
  - metric: exact_match
...
group:
  - super-glue-promptsource
task: "rte"
dataset_path: super_glue
dataset_name: rte
training_split: train
validation_split: validation
use_prompt: "promptsource:GPT-3 style"
generation_kwargs:
  until:
    - "\n"
    - "\n\n"
metric_list:
  - metric: exact_match
    aggregation: mean
...
@@ -94,10 +94,10 @@ class MultiChoice:
    def __contains__(self, values):
        for value in values.split(","):
            if len(fnmatch.filter(self.choices, value)) == 0:
                eval_logger.info(f"Available tasks to choose:")
                for choice in self.choices:
                    eval_logger.info(f"  - {choice}")
                raise ValueError("'{}' is not in task list".format(value))
        return True

    def __iter__(self):
@@ -468,7 +468,8 @@ def pad_and_concat(
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
    for i, tensor in enumerate(tensors):
        if len(tensor.shape) == 2:
            tensor = tensor.squeeze(0)  # squeeze, in case passed [1, seq] size
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
...
@@ -43,7 +43,6 @@ def parse_args():
    parser.add_argument("--decontamination_ngrams_path", default=None)
    parser.add_argument("--check_integrity", action="store_true")
    parser.add_argument("--write_out", action="store_true", default=False)
    return parser.parse_args()
@@ -90,7 +89,6 @@ def main():
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
    )
    if results is not None:
...
@@ -43,7 +43,7 @@ setuptools.setup(
        "sacrebleu==1.5.0",
        "scikit-learn>=0.24.1",
        "sqlitedict",
        "torch>=1.8",
        "tqdm-multiprocess",
        "transformers>=4.1",
        "zstandard",
...