"vscode:/vscode.git/clone" did not exist on "c6cb4d81cee393dc1e41600604f98ec15c355e90"
Commit 4288b53e authored by Baber's avatar Baber
Browse files

Merge branch 'main' into llama

parents 37eb9c9d 94344a61
import warnings

import torch
import torch.nn as nn
from transformer_lens import HookedTransformer
from transformers import AutoConfig

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM


def evaluate_lm_eval(lens_model: HookedTransformer, tasks: list[str], **kwargs):
    class HFLikeModelAdapter(nn.Module):
        """Adapts HookedTransformer to match the HuggingFace interface expected by lm-eval."""

        def __init__(self, model: HookedTransformer):
            super().__init__()
            self.model = model
            self.tokenizer = model.tokenizer
            self.config = AutoConfig.from_pretrained(model.cfg.tokenizer_name)
            self.device = model.cfg.device
            self.tie_weights = lambda: self

        def forward(self, input_ids=None, attention_mask=None, **kwargs):
            output = self.model(input_ids, attention_mask=attention_mask, **kwargs)
            # Make sure output has the expected .logits attribute
            if not hasattr(output, "logits"):
                if isinstance(output, torch.Tensor):
                    output.logits = output
            return output

        # Only delegate specific attributes we know we need
        def to(self, *args, **kwargs):
            return self.model.to(*args, **kwargs)

        def eval(self):
            self.model.eval()
            return self

        def train(self, mode=True):
            self.model.train(mode)
            return self

    model = HFLikeModelAdapter(lens_model)
    warnings.filterwarnings("ignore", message="Failed to get model SHA for")
    results = evaluator.simple_evaluate(
        model=HFLM(pretrained=model, tokenizer=model.tokenizer),
        tasks=tasks,
        verbosity="WARNING",
        **kwargs,
    )
    return results


if __name__ == "__main__":
    # Load base model
    model = HookedTransformer.from_pretrained("pythia-70m")
    res = evaluate_lm_eval(model, tasks=["arc_easy"])
    print(res["results"])
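Because the helper forwards **kwargs to evaluator.simple_evaluate, the usual evaluator options can be passed straight through; a small illustrative example (argument values are arbitrary):

# e.g. 2-shot prompts, capped at 100 documents per task
res = evaluate_lm_eval(model, tasks=["arc_easy"], num_fewshot=2, limit=100)
print(res["results"]["arc_easy"])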
@@ -456,6 +456,7 @@ class Task(abc.ABC):
                 ctx=fewshot_ctx,
                 metadata=(self.config["task"], doc_id, self.config.repeats),
                 apply_chat_template=apply_chat_template,
+                chat_template=chat_template,
             )

             if not isinstance(inst, list):
@@ -1098,6 +1099,8 @@ class ConfigurableTask(Task):
         if apply_chat_template:
             if self.multiple_input:
                 # TODO: append prefill?
+                if not labeled_examples:
+                    return ""
                 return chat_template(labeled_examples)
             if isinstance(example, str):
                 self.append_target_question(
@@ -1350,6 +1353,7 @@ class ConfigurableTask(Task):
         self, doc: dict, ctx: str, **kwargs
     ) -> Union[List[Instance], Instance]:
         apply_chat_template = kwargs.pop("apply_chat_template", False)
+        chat_template: Callable | None = kwargs.pop("chat_template", None)

         aux_arguments = None
@@ -1364,9 +1368,20 @@ class ConfigurableTask(Task):
             target_delimiter = ""
         if self.multiple_input:
             # If there are multiple inputs, choices are placed in the ctx
+            # apply chat_template to choices if apply_chat_template
             cont = self.doc_to_target(doc)
             arguments = [
-                (ctx + choice, f"{target_delimiter}{cont}") for choice in choices
+                (
+                    ctx
+                    + (
+                        chat_template([{"role": "user", "content": choice}])
+                        if apply_chat_template
+                        else choice
+                    ),
+                    f"{target_delimiter}{cont}",
+                )
+                for choice in choices
             ]
         else:
             # Otherwise they are placed in the continuation
...
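For context, the chat_template callable threaded through these hunks is expected to map a list of {"role", "content"} messages to a prompt string. A minimal sketch of such a callable, built on HuggingFace's tokenizer.apply_chat_template (the model name is only an example; this is not the harness's actual plumbing):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

def chat_template(messages: list[dict]) -> str:
    # Render chat messages into the model's prompt format, leaving the
    # assistant turn open so the choice/continuation can follow.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

print(chat_template([{"role": "user", "content": "2 + 2 ="}]))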
@@ -271,7 +271,9 @@ class VLLM_VLM(VLLM):
                 left_truncate_len=max_ctx_len,
             )
-            cont = self._model_generate(inputs, stop=until, generate=True, **kwargs)
+            cont = self._model_generate(
+                inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
+            )

             for output, context in zip(cont, contexts):
                 generated_text = output.outputs[0].text
...
@@ -86,6 +86,7 @@
 | [mmlu_pro](mmlu_pro/README.md) | A refined set of MMLU, integrating more challenging, reasoning-focused questions and expanding the choice set from four to ten options. | English |
 | [mmlusr](mmlusr/README.md) | Variation of MMLU designed to be more rigorous. | English |
 | model_written_evals | Evaluation tasks auto-generated for evaluating a collection of AI Safety concerns. | |
+| [moral_stories](moral_stories/README.md) | A crowd-sourced dataset of structured narratives that describe normative and norm-divergent actions taken by individuals to accomplish certain intentions in concrete situations. | English |
 | [mutual](mutual/README.md) | A retrieval-based dataset for multi-turn dialogue reasoning. | English |
 | [nq_open](nq_open/README.md) | Open domain question answering tasks based on the Natural Questions dataset. | English |
 | [okapi/arc_multilingual](okapi/arc_multilingual/README.md) | Tasks that involve reading comprehension and information retrieval challenges. | Multiple (31 languages) **Machine Translated.** |
...
import yaml

languages = [
    "en",
    "ar",
    "fr",
    "es",
    "hi",
    "de",
    "id",
    "it",
    "ja",
    "ko",
    "pt",
    "zh",
    "yo",
    "bn",
    "sw",
]


def main() -> None:
    for language in languages:
        file_name = f"global_mmlu_{language}.yaml"
        try:
            # Open with "x" so an existing config raises FileExistsError and is
            # left untouched ("w" would silently overwrite and never raise it).
            with open(file_name, "x") as f:
                f.write("# Generated by _generate_configs.py\n")
                yaml.dump(
                    {
                        "include": "_default_yaml",
                        "task": f"global_mmlu_{language}",
                        "dataset_name": language,
                    },
                    f,
                )
        except FileExistsError:
            pass


if __name__ == "__main__":
    main()
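For reference, each file the script emits then looks like the following (yaml.dump sorts keys alphabetically by default), shown here for "en":

# Generated by _generate_configs.py
dataset_name: en
include: _default_yaml
task: global_mmlu_en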
tag:
  - global_mmlu
dataset_path: CohereForAI/Global-MMLU-Lite
dataset_name: ar
test_split: test
fewshot_split: dev
fewshot_config:
...
group: global_mmlu_ar
task:
  - global_mmlu_ar_business
  - global_mmlu_ar_humanities
  - global_mmlu_ar_medical
  - global_mmlu_ar_other
  - global_mmlu_ar_stem
  - global_mmlu_ar_social_sciences
aggregate_metric_list:
  - metric: acc
    weight_by_size: True
metadata:
  version: 0.0
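With weight_by_size: True, the group accuracy is a document-count-weighted mean of the subtask accuracies rather than a plain average; a toy illustration with made-up numbers:

# hypothetical (accuracy, doc count) pairs for two subtasks
subtasks = [(0.60, 200), (0.50, 100)]
weighted = sum(acc * n for acc, n in subtasks) / sum(n for _, n in subtasks)
print(weighted)  # ~0.567, vs. 0.55 for the unweighted mean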
# Generated by _generate_configs.py
include: _ar_template_yaml
process_docs: !function utils.process_business
task: global_mmlu_ar_business
# Generated by _generate_configs.py
include: _ar_template_yaml
process_docs: !function utils.process_humanities
task: global_mmlu_ar_humanities
# Generated by _generate_configs.py
include: _ar_template_yaml
process_docs: !function utils.process_medical
task: global_mmlu_ar_medical
# Generated by _generate_configs.py
include: _ar_template_yaml
process_docs: !function utils.process_other
task: global_mmlu_ar_other
# Generated by _generate_configs.py
include: _ar_template_yaml
process_docs: !function utils.process_social_sciences
task: global_mmlu_ar_social_sciences
# Generated by _generate_configs.py
include: _ar_template_yaml
process_docs: !function utils.process_stem
task: global_mmlu_ar_stem
from functools import partial

CATEGORIES = ["Business", "Humanities", "Medical", "Other", "STEM", "Social Sciences"]


def process_docs(dataset, category):
    # Keep only the documents belonging to the given subject category.
    return dataset.filter(lambda x: x["subject_category"] == category)


# Build one specialised filter per category, e.g. process_business, process_stem.
process_functions = {
    f"process_{category.lower().replace(' ', '_')}": partial(
        process_docs, category=category
    )
    for category in CATEGORIES
}

globals().update(process_functions)
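The globals().update call is what makes names like process_stem and process_social_sciences resolvable by the !function utils.process_* hooks in the YAML configs above. A quick sanity check, assuming the module is importable as utils from the task directory:

import utils

# Registered at import time by the dict comprehension:
assert callable(utils.process_business)
assert callable(utils.process_stem)
assert callable(utils.process_social_sciences)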
dataset_path: CohereForAI/Global-MMLU-Lite
dataset_name: bn
test_split: test
fewshot_split: dev
fewshot_config:
  sampler: default
output_type: multiple_choice
doc_to_text: "{{question.strip()}}\nA. {{option_a}}\nB. {{option_b}}\nC. {{option_c}}\nD. {{option_d}}\nAnswer:"
doc_to_choice: ["A", "B", "C", "D"]
doc_to_target: answer
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
metadata:
  version: 0.0
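To make the doc_to_text template concrete, a hypothetical document (field values invented for illustration) renders as follows:

doc = {
    "question": "What is the capital of France?",
    "option_a": "Berlin",
    "option_b": "Paris",
    "option_c": "Rome",
    "option_d": "Madrid",
    "answer": "B",
}
# Rendered prompt:
# What is the capital of France?
# A. Berlin
# B. Paris
# C. Rome
# D. Madrid
# Answer:
# doc_to_target picks "answer", so the gold choice is "B".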
group: global_mmlu_bn
task:
  - global_mmlu_bn_business
  - global_mmlu_bn_humanities
  - global_mmlu_bn_medical
  - global_mmlu_bn_other
  - global_mmlu_bn_stem
  - global_mmlu_bn_social_sciences
aggregate_metric_list:
  - metric: acc
    weight_by_size: True
metadata:
  version: 0.0
# Generated by _generate_configs.py
include: _bn_template_yaml
process_docs: !function utils.process_business
task: global_mmlu_bn_business
# Generated by _generate_configs.py
include: _bn_template_yaml
process_docs: !function utils.process_humanities
task: global_mmlu_bn_humanities
# Generated by _generate_configs.py
include: _bn_template_yaml
process_docs: !function utils.process_medical
task: global_mmlu_bn_medical
# Generated by _generate_configs.py
include: _bn_template_yaml
process_docs: !function utils.process_other
task: global_mmlu_bn_other