Commit 9b9ca7bf authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

modeling cleanup

parent 9692aa05
import copy
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import torch
import transformers
from tqdm import tqdm
......@@ -102,51 +103,58 @@ class HFMultimodalLM(HFLM):
# return encoding
# def tok_batch_encode(
# self,
# strings: List[str],
# padding_side: str = "left",
# left_truncate_len: int = None,
# truncation: bool = False,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
# old_padding_side = self.tokenizer.padding_side
# self.tokenizer.padding_side = padding_side
# add_special_tokens = {}
# if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
# encoding = self.tokenizer(
# strings,
# truncation=truncation,
# padding="longest",
# return_tensors="pt",
# **add_special_tokens,
# )
# if left_truncate_len:
# encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
# encoding["attention_mask"] = encoding["attention_mask"][
# :, -left_truncate_len:
# ]
# self.tokenizer.padding_side = old_padding_side
def tok_batch_encode(
self,
strings: List[str], # note that input signature of this fn is different
visuals, # TODO: typehint on this
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
) -> Dict[
str, torch.Tensor
]: # TODO: note that this return signature differs from HFLM tok_batch_encode.
# TODO: we should allow
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
add_special_tokens = {}
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
encoding = self.processor(
strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
**add_special_tokens,
).to(
self.device, self.model.dtype
) # TODO: casting to dtype seems odd for input_ids and attn_mask.
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len:
]
self.tokenizer.padding_side = old_padding_side
# return encoding["input_ids"], encoding["attention_mask"]
return encoding
# def tok_decode(self, tokens, skip_special_tokens=True):
# return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def _model_generate(self, inputs, stop, **gen_kwargs):
# TODO: handle max_length
def _model_generate(self, inputs, max_length, stop, **generation_kwargs):
# gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
if "max_new_tokens" not in gen_kwargs:
gen_kwargs["max_new_tokens"] = 1024
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0
if "top_p" not in gen_kwargs:
gen_kwargs["top_p"] = None
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None)
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
generation_kwargs["do_sample"] = do_sample = False
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
generation_kwargs.pop("temperature")
stopping_criteria = stop_sequences_criteria(
self.tokenizer,
......@@ -156,15 +164,11 @@ class HFMultimodalLM(HFLM):
)
return self.model.generate(
**inputs,
# max_length=max_length,
max_length=max_length,
stopping_criteria=stopping_criteria,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_new_tokens"],
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True,
pad_token_id=self.tokenizer.eos_token_id,
**generation_kwargs,
)
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
......@@ -257,21 +261,19 @@ class HFMultimodalLM(HFLM):
max_ctx_len = self.max_length - max_gen_toks # noqa: F841 # TODO: this assumes we are using a causal LM. is that always valid? shouldn't be
self.tokenizer.padding_side = "left"
inputs = self.processor( # TODO: write this as tok_batch_encode (and allow that to either take a visuals value or None)
images=visuals, text=contexts, return_tensors="pt", padding=True
).to(
self.device, self.model.dtype
) # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len
print(inputs)
inputs = self.tok_batch_encode(
contexts,
visuals,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
).to(self.device, self.model.dtype)
context_enc = inputs["input_ids"]
if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
cont = self._model_generate(inputs, stop=until, **gen_kwargs)
cont = self._model_generate(inputs, stop=until, **kwargs)
### essentially same as HFLM beyond this line!
......
import collections
import inspect
import logging
import os
from functools import partial
from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
class TaskManager:
......@@ -86,12 +80,7 @@ class TaskManager:
return False
def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"):
return True
return False
def _name_is_tag(self, name) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True
return False
......@@ -152,126 +141,89 @@ class TaskManager:
config["group_alias"] = None
return config
def _class_has_config_in_constructor(self, cls):
constructor = getattr(cls, "__init__", None)
return (
"config" in inspect.signature(constructor).parameters
if constructor
else False
)
def _load_individual_task_or_group(
self,
name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None,
update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping:
def _load_task(config, task):
def load_task(config, task, group=None, yaml_path=None):
if "include" in config:
if yaml_path is None:
raise ValueError
config = {
**utils.load_yaml_config(
yaml_path=None,
yaml_path,
yaml_config={"include": config.pop("include")},
mode="full",
),
**config,
}
if self._config_is_python_task(config):
if self._class_has_config_in_constructor(config["class"]):
task_object = config["class"](config=config)
else:
task_object = config["class"]()
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"]
task_object = config["class"]()
else:
config = self._process_alias(config, group=group)
task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object}
def _get_group_and_subtask_from_config(config):
group_name = ConfigurableGroup(config=config)
subtask_list = []
for task in group_name.config["task"]:
if isinstance(task, str) and self._name_is_tag(task):
subtask_list.extend(self._get_tasklist(task))
else:
subtask_list.append(task)
return group_name, subtask_list
def _process_group_config(config, update_config=None):
if update_config is not None:
config = {**config, **update_config}
_update_config = {
k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
}
if not bool(_update_config):
_update_config = None
group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
return group_config, _update_config
if isinstance(name_or_config, str):
if update_config is not None:
# Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config) or self._name_is_python_task(
name_or_config
):
elif self._name_is_task(name_or_config):
task_config = self._get_config(name_or_config)
return _load_task(task_config, task=name_or_config)
return load_task(task_config, task=name_or_config, group=parent_name)
else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1:
group_config = self._get_config(name_or_config)
group_config, update_config = _process_group_config(group_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
else:
if self._name_is_tag(name_or_config):
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config
if isinstance(name_or_config, dict)
else None,
)
return dict(
collections.ChainMap(*map(fn, reversed(subtask_list)))
)
else:
group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list}
)
subtask_list = group_config["task"]
# This checks if we're at the root.
if parent_name is None:
group_config = self._get_config(name_or_config)
if set(group_config.keys()) > {"task", "group"}:
update_config = {
k: v
for k, v in group_config.items()
if k not in ["task", "group"]
}
yaml_path = self._get_yaml_path(group_name)
if (update_config is not None) and ("group_alias" in update_config):
group_name = update_config["group_alias"]
update_config.pop("group_alias")
if isinstance(name_or_config, dict):
if update_config is not None:
name_or_config = {
**name_or_config,
**update_config,
}
if self._config_is_task(name_or_config):
name = name_or_config.pop("task")
if update_config is not None:
name_or_config = {**name_or_config, **update_config}
name = name_or_config["task"]
# If the name is registered as a group
# if self._name_is_task(name) is False:
if self._name_is_group(name):
group_config = self._get_config(name)
group_config, update_config = _process_group_config(
group_config, name_or_config
)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
elif self._name_is_tag(name):
group_name = name
update_config = {
k: v for k, v in name_or_config.items() if k != "task"
}
subtask_list = self._get_tasklist(name)
fn = partial(
self._load_individual_task_or_group,
update_config=name_or_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
if subtask_list == -1:
subtask_list = self._get_config(name)["task"]
else:
if self._name_is_registered(name):
base_task_config = self._get_config(name)
# Check if this is a duplicate.
if parent_name is not None:
name_or_config["group"] = parent_name
num_duplicate = len(
list(
filter(
......@@ -290,21 +242,34 @@ class TaskManager:
}
else:
task_config = name_or_config
return _load_task(task_config, task=name)
return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else:
group_config, update_config = _process_group_config(name_or_config)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
group_name = name_or_config["group"]
subtask_list = name_or_config["task"]
if set(name_or_config.keys()) > {"task", "group"}:
update_config = {
k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
}
all_subtasks = {}
if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)}
fn = partial(
self._load_individual_task_or_group,
parent_name=group_name,
update_config=update_config,
yaml_path=yaml_path,
)
return {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
all_subtasks = {
**all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
}
return all_subtasks
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list
......@@ -328,11 +293,10 @@ class TaskManager:
def _get_task_and_group(self, task_dir: str):
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
- `type`, that can be either `task`, `python_task`, or `group`.
`task` refer to regular task configs, `python_task` are special
yaml files that only consists of `task` and `class` parameters.
`group` are group configs. `tags` are labels that can be assigned
to tasks to assist in sorting and calling tasks of certain themes.
`group` are group configs.
- `yaml_path`, path to the yaml file. If the entry is a `group` that
was configured through a task config, the yaml_path will be -1
and all subtasks will be listed in `task` (see below)
......@@ -348,8 +312,6 @@ class TaskManager:
:return
Dictionary of task names as key and task metadata
"""
# TODO: remove group in next release
print_info = True
ignore_dirs = [
"__pycache__",
".ipynb_checkpoints",
......@@ -396,38 +358,20 @@ class TaskManager:
"yaml_path": yaml_path,
}
# TODO: remove group in next release
for attr in ["tag", "group"]:
if attr in config:
if attr == "group" and print_info:
self.logger.info(
"`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. "
"`tag` will be used to allow to call a collection of tasks just like `group`. "
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup "
"which will be the offical way to create groups with addition of group-wide configuations."
)
print_info = False
# attr = "tag"
attr_list = config[attr]
if isinstance(attr_list, str):
attr_list = [attr_list]
for tag in attr_list:
if tag not in tasks_and_groups:
tasks_and_groups[tag] = {
"type": "tag",
"task": [task],
"yaml_path": -1,
}
elif tasks_and_groups[tag]["type"] != "tag":
self.logger.info(
f"The tag {tag} is already registered as a group, this tag will not be registered. "
"This may affect tasks you want to call."
)
break
else:
tasks_and_groups[tag]["task"].append(task)
if "group" in config:
groups = config["group"]
if isinstance(config["group"], str):
groups = [groups]
for group in groups:
if group not in tasks_and_groups:
tasks_and_groups[group] = {
"type": "group",
"task": [task],
"yaml_path": -1,
}
else:
tasks_and_groups[group]["task"].append(task)
else:
self.logger.debug(f"File {f} in {root} could not be loaded")
......@@ -456,33 +400,6 @@ def get_task_name_from_object(task_object):
)
def _check_duplicates(task_dict: dict) -> List[str]:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
"oversubscribed" to several disjoint groups.
"""
subtask_names = []
for key, value in task_dict.items():
subtask_names.extend(value)
duplicate_tasks = {
task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
}
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks
competing_groups = [
group
for group in task_dict.keys()
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
]
if len(duplicate_tasks) > 0:
raise ValueError(
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
)
def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None,
......@@ -500,7 +417,6 @@ def get_task_dict(
:return
Dictionary of task objects
"""
task_name_from_string_dict = {}
task_name_from_config_dict = {}
task_name_from_object_dict = {}
......@@ -547,16 +463,8 @@ def get_task_dict(
):
raise ValueError
final_task_dict = {
return {
**task_name_from_string_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
# behavior can get odd if one tries to invoke several groups that "compete" for the same task.
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
_check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment