Commit 7385e116 authored by lintang

changes for multimodal runs

parent 8138fd52
@@ -21,6 +21,7 @@ from tqdm import tqdm
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
 )
 
 from lm_eval import utils
@@ -37,6 +38,10 @@ from lm_eval.models.utils import (
 )
 
+import llava
+import cambrian
+import palo
+
 eval_logger = utils.eval_logger
@@ -566,6 +571,11 @@ class HFLM(TemplateLM):
                 getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
             ):
                 self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+            elif (
+                getattr(self.config, "model_type")
+                in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
+            ):
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq
             else:
                 if not trust_remote_code:
                     eval_logger.warning(
...@@ -579,6 +589,7 @@ class HFLM(TemplateLM): ...@@ -579,6 +589,7 @@ class HFLM(TemplateLM):
assert self.AUTO_MODEL_CLASS in [ assert self.AUTO_MODEL_CLASS in [
transformers.AutoModelForCausalLM, transformers.AutoModelForCausalLM,
transformers.AutoModelForSeq2SeqLM, transformers.AutoModelForSeq2SeqLM,
transformers.AutoModelForVision2Seq,
] ]
return None return None
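Reviewer note: like the existing causal-LM branch, the new `elif` dispatches on `config.model_type` via the auto-model mapping tables. A standalone sketch of that lookup; the BLIP-2 checkpoint is only an illustrative vision2seq model, not something this commit depends on:

```python
from transformers import AutoConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES

# Which branch would a given checkpoint take? (Mapping contents vary by
# transformers version; "blip-2" is registered in recent releases.)
config = AutoConfig.from_pretrained("Salesforce/blip2-opt-2.7b")  # example checkpoint
print(config.model_type)  # "blip-2"
print(config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)  # True -> AutoModelForVision2Seq
```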
@@ -837,7 +848,7 @@ class HFLM(TemplateLM):
         # by default for CausalLM - false or self.add_bos_token is set
         if add_special_tokens is None:
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            if (self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM) or (self.AUTO_MODEL_CLASS == transformers.AutoModelForVision2Seq):
                 special_tokens_kwargs = {
                     "add_special_tokens": False or self.add_bos_token
                 }
...@@ -865,7 +876,7 @@ class HFLM(TemplateLM): ...@@ -865,7 +876,7 @@ class HFLM(TemplateLM):
self.tokenizer.padding_side = padding_side self.tokenizer.padding_side = padding_side
add_special_tokens = {} add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: if (self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM) or (self.AUTO_MODEL_CLASS == transformers.AutoModelForVision2Seq):
add_special_tokens = {"add_special_tokens": False or self.add_bos_token} add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
encoding = self.tokenizer( encoding = self.tokenizer(
@@ -910,8 +921,11 @@ class HFLM(TemplateLM):
                     input_ids=inps, attention_mask=attn_mask, labels=labels
                 ).logits
             else:
-                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
-                return self.model(inps).logits
+                assert (
+                    self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    or self.AUTO_MODEL_CLASS == transformers.AutoModelForVision2Seq
+                )
+                return self.model(input_ids=inps).logits
 
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # temperature = 0.0 if not set
@@ -943,20 +957,21 @@ class HFLM(TemplateLM):
     def _select_cont_toks(
         self, logits: torch.Tensor, contlen: int = None, inplen: int = None
     ) -> torch.Tensor:
-        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            assert (
-                contlen and inplen
-            ), "Must pass input len and cont. len to select scored logits for causal LM"
-            # discard right-padding.
-            # also discard the input/context tokens. we'll only score continuations.
-            logits = logits[inplen - contlen : inplen]
-        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+        if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             assert (
                 contlen and not inplen
             ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
             # only discard right-padding.
             # the logits input to this fn only contain decoder-side tokens.
             logits = logits[:contlen]
+        # if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+        else:
+            assert (
+                contlen and inplen
+            ), "Must pass input len and cont. len to select scored logits for causal LM"
+            # discard right-padding.
+            # also discard the input/context tokens. we'll only score continuations.
+            logits = logits[inplen - contlen : inplen]
 
         return logits
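Reviewer note: the reordered branches keep the same slicing behaviour; only the default case widens to cover vision2seq. A toy illustration of the causal/vision2seq slice, with made-up shapes:

```python
import torch

# Right-padded logits for one sequence: inplen real positions, then padding.
inplen, contlen = 6, 2
logits = torch.randn(8, 10)  # [padded_seq_len, vocab]
selected = logits[inplen - contlen : inplen]
print(selected.shape)  # torch.Size([2, 10]): one row per continuation token
```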
@@ -1073,7 +1088,7 @@ class HFLM(TemplateLM):
             requests,
             sort_fn=_collate,
             group_by="contexts"
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+            if (self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM) or (self.AUTO_MODEL_CLASS == transformers.AutoModelForVision2Seq)
             and self.logits_cache
             else None,
             group_fn=_lookup_one_token_cont,
@@ -1131,14 +1146,7 @@ class HFLM(TemplateLM):
             # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
 
             # when too long to fit in context, truncate from the left
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                inp = torch.tensor(
-                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
-                    dtype=torch.long,
-                    device=self.device,
-                )
-                (inplen,) = inp.shape
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 inp = torch.tensor(
                     (context_enc)[-self.max_length :],
                     dtype=torch.long,
@@ -1165,6 +1173,14 @@ class HFLM(TemplateLM):
                     if padding_len_cont is not None
                     else contlen
                 )
+            else:
+                # if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                inp = torch.tensor(
+                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
+                    dtype=torch.long,
+                    device=self.device,
+                )
+                (inplen,) = inp.shape
 
             padding_len_inp = (
                 max(padding_len_inp, inplen)
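Reviewer note: the causal/vision2seq input construction moves into the `else` branch unchanged. A toy trace of the slice, with invented token ids:

```python
context_enc = [11, 12, 13]
continuation_enc = [14, 15]
max_length = 4

# Keep at most max_length + 1 trailing tokens, then drop the final token:
# the model is asked to predict it, not read it.
inp = (context_enc + continuation_enc)[-(max_length + 1) :][:-1]
print(inp)  # [11, 12, 13, 14]; the logits at the position of 14 score token 15
```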
@@ -1178,11 +1194,7 @@ class HFLM(TemplateLM):
             # create encoder attn mask and batched conts, if seq2seq
             call_kwargs = {}
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                batched_inps = pad_and_concat(
-                    padding_len_inp, inps, padding_side="right"
-                )  # [batch, padding_len_inp]
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 # TODO: left-pad encoder inps and mask?
                 batched_inps = pad_and_concat(
                     padding_len_inp, inps
@@ -1197,6 +1209,11 @@ class HFLM(TemplateLM):
                     "attn_mask": batched_encoder_mask,
                     "labels": batched_conts,
                 }
+            else:
+                # if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                batched_inps = pad_and_concat(
+                    padding_len_inp, inps, padding_side="right"
+                )  # [batch, padding_len_inp]
 
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps, **call_kwargs), dim=-1
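Reviewer note: the right-padding path is likewise only re-homed under `else`. A minimal sketch of what this call does, assuming `lm_eval.models.utils.pad_and_concat` keeps its current `(max_length, tensors, padding_side)` signature and zero-padding:

```python
import torch

from lm_eval.models.utils import pad_and_concat

inps = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
batched = pad_and_concat(3, inps, padding_side="right")
print(batched)  # tensor([[1, 2, 3], [4, 5, 0]]) -- [batch, padding_len_inp]
```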
@@ -1213,7 +1230,7 @@ class HFLM(TemplateLM):
                 # from prompt/prefix tuning tokens, if applicable
                 ctx_len = (
                     inplen + (logits.shape[0] - padding_len_inp)
-                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    if (self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM) or (self.AUTO_MODEL_CLASS == transformers.AutoModelForVision2Seq)
                     else None
                 )
                 logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
@@ -1348,12 +1365,13 @@ class HFLM(TemplateLM):
             max_gen_toks = self.max_gen_toks
 
             # set the max length in tokens of inputs ("context_enc")
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                # max len for inputs = max length, minus room to generate the max new tokens
-                max_ctx_len = self.max_length - max_gen_toks
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 # max len for inputs = encoder's whole max_length
                 max_ctx_len = self.max_length
+            # if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            else:
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
 
             # encode, pad, and truncate contexts for this batch
             context_enc, attn_masks = self.tok_batch_encode(
@@ -1378,7 +1396,7 @@ class HFLM(TemplateLM):
             cont_toks_list = cont.tolist()
             for cont_toks, context in zip(cont_toks_list, contexts):
                 # discard context + left-padding toks if using causal decoder-only LM
-                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                if (self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM) or (self.AUTO_MODEL_CLASS == transformers.AutoModelForVision2Seq):
                     cont_toks = cont_toks[context_enc.shape[1] :]
 
                 s = self.tok_decode(cont_toks)
 # This file will be included in the generated language-specific task configs.
 # It doesn't have a yaml file extension as it is not meant to be imported directly
 # by the harness.
-group: mgsm_cot_native
+group: mgsm_cot_en
 dataset_path: juletxara/mgsm
 dataset_name: null # Overridden by language-specific config.
 output_type: generate_until
# TyDi QA
### Paper
Title: `TyDi QA: A Benchmark for Information-Seeking Question Answering in Typologically Diverse Languages`
Abstract: https://arxiv.org/pdf/2003.05002
Confidently making progress on multilingual modeling requires challenging, trustworthy evaluations. We present TyDi QA---a question answering dataset covering 11 typologically diverse languages with 204K question-answer pairs. The languages of TyDi QA are diverse with regard to their typology---the set of linguistic features each language expresses---such that we expect models performing well on this set to generalize across a large number of the world's languages. We present a quantitative analysis of the data quality and example-level qualitative linguistic analyses of observed language phenomena that would not be found in English-only corpora. To provide a realistic information-seeking task and avoid priming effects, questions are written by people who want to know the answer, but don't know the answer yet, and the data is collected directly in each language without the use of translation.
Homepage: https://ai.google.com/research/tydiqa/
Acknowledgement: This implementation is based on the official evaluation for `TyDiQA Gold`:
https://github.com/google-research-datasets/tydiqa/blob/master/tydi_eval.py
### Citation
```
@article{tydiqa,
    title = {TyDi QA: A Benchmark for Information-Seeking Question Answering in Typologically Diverse Languages},
    author = {Jonathan H. Clark and Eunsol Choi and Michael Collins and Dan Garrette and Tom Kwiatkowski and Vitaly Nikolaev and Jennimaria Palomaki},
    year = {2020},
    journal = {Transactions of the Association for Computational Linguistics}
}
```
### Groups and Tasks
#### Groups

* `tydiqa`

#### Tasks

* `tydiqa_arabic`
* `tydiqa_bengali`
* `tydiqa_english`
* `tydiqa_finnish`
* `tydiqa_indonesian`
* `tydiqa_korean`
* `tydiqa_russian`
* `tydiqa_swahili`
* `tydiqa_telugu`
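The group and the per-language tasks run through the standard harness entry points. A minimal sketch; the checkpoint name is illustrative only:

```python
import lm_eval

# Evaluate the whole `tydiqa` group (or pass e.g. ["tydiqa_english"] for one language).
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m",  # illustrative checkpoint
    tasks=["tydiqa"],
)
print(results["results"])
```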
### Checklist
For adding novel benchmarks/datasets to the library:
* [ ] Is the task an existing benchmark in the literature?
* [ ] Have you referenced the original paper that introduced the task?
* [ ] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?
If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
group: tydiqa
task:
  - tydiqa_arabic
  - tydiqa_bengali
  - tydiqa_english
  - tydiqa_finnish
  - tydiqa_indonesian
  - tydiqa_korean
  - tydiqa_russian
  - tydiqa_swahili
  - tydiqa_telugu
aggregate_metric_list:
  - metric: f1
    aggregation: mean
    weight_by_size: False
metadata:
  version: 1.0
task: tydiqa_arabic
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_arabic
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_bengali
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_bengali
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_english
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_english
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_finnish
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_finnish
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_indonesian
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_indonesian
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_korean
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_korean
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_russian
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_russian
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_swahili
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_swahili
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
task: tydiqa_telugu
dataset_path: google-research-datasets/tydiqa
dataset_name: secondary_task
output_type: generate_until
process_docs: !function util.filter_telugu
training_split: train
validation_split: validation
doc_to_text: "Answer the following question based on Context. You can extract an answer span from Context.\nContext: {{context}} Question: {{question}}?\nAnswer:"
doc_to_target: !function util.doc_to_target
process_results: !function util.process_results
num_fewshot: 4
generation_kwargs:
  until:
    - "."
    - "\n"
filter_list:
  - name: remove_whitespace
    filter:
      - function: remove_whitespace
metric_list:
  - metric: em
    aggregation: mean
    higher_is_better: true
  - metric: f1
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
import transformers.data.metrics.squad_metrics as squad_metrics

from lm_eval.api.metrics import metric_max_over_ground_truths


def doc_to_target(doc):
    # Use the first gold answer span as the target shown in fewshot examples.
    return doc["answers"]["text"][0]


def filter_arabic(dataset):
    return dataset.filter(lambda example: example["id"].startswith("arabic"))


def filter_bengali(dataset):
    return dataset.filter(lambda example: example["id"].startswith("bengali"))


def filter_finnish(dataset):
    return dataset.filter(lambda example: example["id"].startswith("finnish"))


def filter_indonesian(dataset):
    return dataset.filter(lambda example: example["id"].startswith("indonesian"))


def filter_russian(dataset):
    return dataset.filter(lambda example: example["id"].startswith("russian"))


def filter_korean(dataset):
    return dataset.filter(lambda example: example["id"].startswith("korean"))


def filter_english(dataset):
    return dataset.filter(lambda example: example["id"].startswith("english"))


def filter_swahili(dataset):
    return dataset.filter(lambda example: example["id"].startswith("swahili"))


def filter_telugu(dataset):
    return dataset.filter(lambda example: example["id"].startswith("telugu"))


def process_results(doc, results):
    # Score the generated answer against every gold span and keep the best match.
    gold_label_set = doc["answers"]["text"]
    f1 = metric_max_over_ground_truths(
        squad_metrics.compute_f1, results[0][0], gold_label_set
    )
    em = metric_max_over_ground_truths(
        squad_metrics.compute_exact, results[0][0], gold_label_set
    )

    return {
        "f1": f1,
        "em": em,
    }
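
if __name__ == "__main__":
    # Hypothetical smoke test, not part of the harness config: toy inputs shaped
    # the way process_results expects them here (results[0][0] is the generated
    # answer string after filtering).
    _doc = {"answers": {"text": ["Jakarta", "the city of Jakarta"]}}
    _results = [["Jakarta"]]
    print(process_results(_doc, _results))  # exact match -> em and f1 both 1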