Commit 9b9ca7bf authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

modeling cleanup

parent 9692aa05
import copy import copy
from typing import List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch
import transformers import transformers
from tqdm import tqdm from tqdm import tqdm
...@@ -102,51 +103,58 @@ class HFMultimodalLM(HFLM): ...@@ -102,51 +103,58 @@ class HFMultimodalLM(HFLM):
# return encoding # return encoding
# def tok_batch_encode( def tok_batch_encode(
# self, self,
# strings: List[str], strings: List[str], # note that input signature of this fn is different
# padding_side: str = "left", visuals, # TODO: typehint on this
# left_truncate_len: int = None, padding_side: str = "left",
# truncation: bool = False, left_truncate_len: int = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]: truncation: bool = False,
# # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode. ) -> Dict[
# old_padding_side = self.tokenizer.padding_side str, torch.Tensor
# self.tokenizer.padding_side = padding_side ]: # TODO: note that this return signature differs from HFLM tok_batch_encode.
# TODO: we should allow
# add_special_tokens = {}
# if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
# add_special_tokens = {"add_special_tokens": False or self.add_bos_token} old_padding_side = self.tokenizer.padding_side
self.tokenizer.padding_side = padding_side
# encoding = self.tokenizer(
# strings, add_special_tokens = {}
# truncation=truncation, if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# padding="longest", add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
# return_tensors="pt",
# **add_special_tokens, encoding = self.processor(
# ) strings,
# if left_truncate_len: truncation=truncation,
# encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] padding="longest",
# encoding["attention_mask"] = encoding["attention_mask"][ return_tensors="pt",
# :, -left_truncate_len: **add_special_tokens,
# ] ).to(
# self.tokenizer.padding_side = old_padding_side self.device, self.model.dtype
) # TODO: casting to dtype seems odd for input_ids and attn_mask.
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
:, -left_truncate_len:
]
self.tokenizer.padding_side = old_padding_side
# return encoding["input_ids"], encoding["attention_mask"] return encoding
# def tok_decode(self, tokens, skip_special_tokens=True): # def tok_decode(self, tokens, skip_special_tokens=True):
# return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) # return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
def _model_generate(self, inputs, stop, **gen_kwargs): def _model_generate(self, inputs, max_length, stop, **generation_kwargs):
# TODO: handle max_length
# gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))] # gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
if "max_new_tokens" not in gen_kwargs: generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
gen_kwargs["max_new_tokens"] = 1024 do_sample = generation_kwargs.get("do_sample", None)
if "temperature" not in gen_kwargs:
gen_kwargs["temperature"] = 0 # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if "top_p" not in gen_kwargs: if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
gen_kwargs["top_p"] = None generation_kwargs["do_sample"] = do_sample = False
if "num_beams" not in gen_kwargs:
gen_kwargs["num_beams"] = 1 if do_sample is False and generation_kwargs.get("temperature") == 0.0:
generation_kwargs.pop("temperature")
stopping_criteria = stop_sequences_criteria( stopping_criteria = stop_sequences_criteria(
self.tokenizer, self.tokenizer,
...@@ -156,15 +164,11 @@ class HFMultimodalLM(HFLM): ...@@ -156,15 +164,11 @@ class HFMultimodalLM(HFLM):
) )
return self.model.generate( return self.model.generate(
**inputs, **inputs,
# max_length=max_length, max_length=max_length,
stopping_criteria=stopping_criteria, stopping_criteria=stopping_criteria,
do_sample=True if gen_kwargs["temperature"] > 0 else False, pad_token_id=self.tokenizer.pad_token_id,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=True, use_cache=True,
pad_token_id=self.tokenizer.eos_token_id, **generation_kwargs,
) )
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
...@@ -257,21 +261,19 @@ class HFMultimodalLM(HFLM): ...@@ -257,21 +261,19 @@ class HFMultimodalLM(HFLM):
max_ctx_len = self.max_length - max_gen_toks # noqa: F841 # TODO: this assumes we are using a causal LM. is that always valid? shouldn't be max_ctx_len = self.max_length - max_gen_toks # noqa: F841 # TODO: this assumes we are using a causal LM. is that always valid? shouldn't be
self.tokenizer.padding_side = "left" inputs = self.tok_batch_encode(
inputs = self.processor( # TODO: write this as tok_batch_encode (and allow that to either take a visuals value or None) contexts,
images=visuals, text=contexts, return_tensors="pt", padding=True visuals,
).to( left_truncate_len=max_ctx_len,
self.device, self.model.dtype truncation=self.truncation,
) # TODO: factor out into a tok_batch_encode bit ; truncate from left using max_ctx_len ).to(self.device, self.model.dtype)
print(inputs)
context_enc = inputs["input_ids"] context_enc = inputs["input_ids"]
if "max_length" not in kwargs: if "max_length" not in kwargs:
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
cont = self._model_generate(inputs, stop=until, **gen_kwargs) cont = self._model_generate(inputs, stop=until, **kwargs)
### essentially same as HFLM beyond this line! ### essentially same as HFLM beyond this line!
......
import collections import collections
import inspect
import logging import logging
import os import os
from functools import partial from functools import partial
from typing import Dict, List, Mapping, Optional, Union from typing import Dict, List, Mapping, Optional, Union
from lm_eval import utils from lm_eval import utils
from lm_eval.api.group import ConfigurableGroup, GroupConfig
from lm_eval.api.task import ConfigurableTask, Task from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.evaluator_utils import get_subtask_list
GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys())
class TaskManager: class TaskManager:
...@@ -86,12 +80,7 @@ class TaskManager: ...@@ -86,12 +80,7 @@ class TaskManager:
return False return False
def _name_is_task(self, name) -> bool: def _name_is_task(self, name) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): if self._name_is_registered(name) and ("task" in self.task_index[name]["type"]):
return True
return False
def _name_is_tag(self, name) -> bool:
if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"):
return True return True
return False return False
...@@ -152,126 +141,89 @@ class TaskManager: ...@@ -152,126 +141,89 @@ class TaskManager:
config["group_alias"] = None config["group_alias"] = None
return config return config
def _class_has_config_in_constructor(self, cls):
constructor = getattr(cls, "__init__", None)
return (
"config" in inspect.signature(constructor).parameters
if constructor
else False
)
def _load_individual_task_or_group( def _load_individual_task_or_group(
self, self,
name_or_config: Optional[Union[str, dict]] = None, name_or_config: Optional[Union[str, dict]] = None,
parent_name: Optional[str] = None, parent_name: Optional[str] = None,
update_config: Optional[dict] = None, update_config: Optional[dict] = None,
yaml_path: Optional[str] = None,
) -> Mapping: ) -> Mapping:
def _load_task(config, task): def load_task(config, task, group=None, yaml_path=None):
if "include" in config: if "include" in config:
if yaml_path is None:
raise ValueError
config = { config = {
**utils.load_yaml_config( **utils.load_yaml_config(
yaml_path=None, yaml_path,
yaml_config={"include": config.pop("include")}, yaml_config={"include": config.pop("include")},
mode="full", mode="full",
), ),
**config, **config,
} }
if self._config_is_python_task(config): if self._config_is_python_task(config):
if self._class_has_config_in_constructor(config["class"]): task_object = config["class"]()
task_object = config["class"](config=config)
else:
task_object = config["class"]()
if isinstance(task_object, ConfigurableTask):
# very scuffed: set task name here. TODO: fixme?
task_object.config.task = config["task"]
else: else:
config = self._process_alias(config, group=group)
task_object = ConfigurableTask(config=config) task_object = ConfigurableTask(config=config)
if group is not None:
task_object = (group, task_object)
return {task: task_object} return {task: task_object}
def _get_group_and_subtask_from_config(config):
    """Build a ConfigurableGroup from ``config`` and flatten its task list.

    Each entry of the group's ``config["task"]`` is kept as-is, except for
    string entries that ``self._name_is_tag`` recognises: those are replaced
    by whatever ``self._get_tasklist`` returns for that name.

    :param config: mapping passed straight to ``ConfigurableGroup(config=...)``.
    :return: ``(group, subtasks)`` -- the constructed group object and the
        flattened list of subtask entries.
    """
    group = ConfigurableGroup(config=config)
    expanded = []
    for entry in group.config["task"]:
        entry_is_tag = isinstance(entry, str) and self._name_is_tag(entry)
        if not entry_is_tag:
            expanded.append(entry)
            continue
        # A tag stands in for several tasks; splice them in at this position.
        expanded.extend(self._get_tasklist(entry))
    return group, expanded
def _process_group_config(config, update_config=None):
    """Split a config into its group-level part and the remaining overrides.

    ``update_config`` (when given) is layered on top of ``config`` first, so
    its values win on key collisions. Keys listed in ``GROUP_ONLY_KEYS`` go to
    the group config; every other key goes to the override mapping.

    :param config: base configuration mapping.
    :param update_config: optional mapping merged over ``config``.
    :return: ``(group_config, overrides)`` where ``overrides`` is ``None``
        when no non-group keys are present.
    """
    merged = config if update_config is None else {**config, **update_config}

    group_config = {}
    overrides = {}
    for key, value in merged.items():
        bucket = group_config if key in GROUP_ONLY_KEYS else overrides
        bucket[key] = value

    # An empty override mapping is reported as None rather than {}.
    return group_config, (overrides or None)
if isinstance(name_or_config, str): if isinstance(name_or_config, str):
if update_config is not None: if update_config is not None:
# Process name_or_config as a dict instead # Process name_or_config as a dict instead
name_or_config = {"task": name_or_config, **update_config} name_or_config = {"task": name_or_config, **update_config}
elif self._name_is_task(name_or_config) or self._name_is_python_task( elif self._name_is_task(name_or_config):
name_or_config
):
task_config = self._get_config(name_or_config) task_config = self._get_config(name_or_config)
return _load_task(task_config, task=name_or_config) return load_task(task_config, task=name_or_config, group=parent_name)
else: else:
group_name = name_or_config
subtask_list = self._get_tasklist(name_or_config) subtask_list = self._get_tasklist(name_or_config)
if subtask_list == -1: if subtask_list == -1:
group_config = self._get_config(name_or_config) group_config = self._get_config(name_or_config)
group_config, update_config = _process_group_config(group_config) subtask_list = group_config["task"]
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config # This checks if we're at the root.
) if parent_name is None:
else: group_config = self._get_config(name_or_config)
if self._name_is_tag(name_or_config): if set(group_config.keys()) > {"task", "group"}:
fn = partial( update_config = {
self._load_individual_task_or_group, k: v
update_config=name_or_config for k, v in group_config.items()
if isinstance(name_or_config, dict) if k not in ["task", "group"]
else None, }
) yaml_path = self._get_yaml_path(group_name)
return dict(
collections.ChainMap(*map(fn, reversed(subtask_list))) if (update_config is not None) and ("group_alias" in update_config):
) group_name = update_config["group_alias"]
else: update_config.pop("group_alias")
group_name = ConfigurableGroup(
config={"group": name_or_config, "task": subtask_list}
)
if isinstance(name_or_config, dict): if isinstance(name_or_config, dict):
if update_config is not None:
name_or_config = {
**name_or_config,
**update_config,
}
if self._config_is_task(name_or_config): if self._config_is_task(name_or_config):
name = name_or_config.pop("task") name = name_or_config["task"]
if update_config is not None:
name_or_config = {**name_or_config, **update_config}
# If the name is registered as a group # If the name is registered as a group
# if self._name_is_task(name) is False:
if self._name_is_group(name): if self._name_is_group(name):
group_config = self._get_config(name) group_name = name
update_config = {
group_config, update_config = _process_group_config( k: v for k, v in name_or_config.items() if k != "task"
group_config, name_or_config }
)
group_name, subtask_list = _get_group_and_subtask_from_config(
group_config
)
elif self._name_is_tag(name):
subtask_list = self._get_tasklist(name) subtask_list = self._get_tasklist(name)
fn = partial( if subtask_list == -1:
self._load_individual_task_or_group, subtask_list = self._get_config(name)["task"]
update_config=name_or_config,
)
return dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
else: else:
if self._name_is_registered(name): if self._name_is_registered(name):
base_task_config = self._get_config(name) base_task_config = self._get_config(name)
# Check if this is a duplicate. # Check if this is a duplicate.
if parent_name is not None: if parent_name is not None:
name_or_config["group"] = parent_name
num_duplicate = len( num_duplicate = len(
list( list(
filter( filter(
...@@ -290,21 +242,34 @@ class TaskManager: ...@@ -290,21 +242,34 @@ class TaskManager:
} }
else: else:
task_config = name_or_config task_config = name_or_config
return _load_task(task_config, task=name) return load_task(
task_config, task=name, group=parent_name, yaml_path=yaml_path
)
else: else:
group_config, update_config = _process_group_config(name_or_config) group_name = name_or_config["group"]
group_name, subtask_list = _get_group_and_subtask_from_config( subtask_list = name_or_config["task"]
group_config if set(name_or_config.keys()) > {"task", "group"}:
) update_config = {
k: v
for k, v in name_or_config.items()
if k not in ["task", "group"]
}
all_subtasks = {}
if parent_name is not None:
all_subtasks = {group_name: (parent_name, None)}
fn = partial( fn = partial(
self._load_individual_task_or_group, self._load_individual_task_or_group,
parent_name=group_name, parent_name=group_name,
update_config=update_config, update_config=update_config,
yaml_path=yaml_path,
) )
return { all_subtasks = {
group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) **all_subtasks,
**dict(collections.ChainMap(*map(fn, subtask_list))),
} }
return all_subtasks
def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict: def load_task_or_group(self, task_list: Optional[Union[str, list]] = None) -> dict:
"""Loads a dictionary of task objects from a list """Loads a dictionary of task objects from a list
...@@ -328,11 +293,10 @@ class TaskManager: ...@@ -328,11 +293,10 @@ class TaskManager:
def _get_task_and_group(self, task_dir: str): def _get_task_and_group(self, task_dir: str):
"""Creates a dictionary of tasks index with the following metadata, """Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`. - `type`, that can be either `task`, `python_task`, or `group`.
`task` refer to regular task configs, `python_task` are special `task` refer to regular task configs, `python_task` are special
yaml files that only consists of `task` and `class` parameters. yaml files that only consists of `task` and `class` parameters.
`group` are group configs. `tags` are labels that can be assigned `group` are group configs.
to tasks to assist in sorting and calling tasks of certain themes.
- `yaml_path`, path to the yaml file. If the entry is a `group` that - `yaml_path`, path to the yaml file. If the entry is a `group` that
was configured through a task config, the yaml_path will be -1 was configured through a task config, the yaml_path will be -1
and all subtasks will be listed in `task` (see below) and all subtasks will be listed in `task` (see below)
...@@ -348,8 +312,6 @@ class TaskManager: ...@@ -348,8 +312,6 @@ class TaskManager:
:return :return
Dictionary of task names as key and task metadata Dictionary of task names as key and task metadata
""" """
# TODO: remove group in next release
print_info = True
ignore_dirs = [ ignore_dirs = [
"__pycache__", "__pycache__",
".ipynb_checkpoints", ".ipynb_checkpoints",
...@@ -396,38 +358,20 @@ class TaskManager: ...@@ -396,38 +358,20 @@ class TaskManager:
"yaml_path": yaml_path, "yaml_path": yaml_path,
} }
# TODO: remove group in next release if "group" in config:
for attr in ["tag", "group"]: groups = config["group"]
if attr in config: if isinstance(config["group"], str):
if attr == "group" and print_info: groups = [groups]
self.logger.info(
"`group` and `group_alias` keys in tasks' configs will no longer be used in the next release of lm-eval. " for group in groups:
"`tag` will be used to allow to call a collection of tasks just like `group`. " if group not in tasks_and_groups:
"`group` will be removed in order to not cause confusion with the new ConfigurableGroup " tasks_and_groups[group] = {
"which will be the offical way to create groups with addition of group-wide configuations." "type": "group",
) "task": [task],
print_info = False "yaml_path": -1,
# attr = "tag" }
else:
attr_list = config[attr] tasks_and_groups[group]["task"].append(task)
if isinstance(attr_list, str):
attr_list = [attr_list]
for tag in attr_list:
if tag not in tasks_and_groups:
tasks_and_groups[tag] = {
"type": "tag",
"task": [task],
"yaml_path": -1,
}
elif tasks_and_groups[tag]["type"] != "tag":
self.logger.info(
f"The tag {tag} is already registered as a group, this tag will not be registered. "
"This may affect tasks you want to call."
)
break
else:
tasks_and_groups[tag]["task"].append(task)
else: else:
self.logger.debug(f"File {f} in {root} could not be loaded") self.logger.debug(f"File {f} in {root} could not be loaded")
...@@ -456,33 +400,6 @@ def get_task_name_from_object(task_object): ...@@ -456,33 +400,6 @@ def get_task_name_from_object(task_object):
) )
def _check_duplicates(task_dict: dict) -> List[str]:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
"oversubscribed" to several disjoint groups.
"""
subtask_names = []
for key, value in task_dict.items():
subtask_names.extend(value)
duplicate_tasks = {
task_name for task_name in subtask_names if subtask_names.count(task_name) > 1
}
# locate the potentially problematic groups that seem to 'compete' for constituent subtasks
competing_groups = [
group
for group in task_dict.keys()
if len(set(task_dict[group]).intersection(duplicate_tasks)) > 0
]
if len(duplicate_tasks) > 0:
raise ValueError(
f"Found 1 or more tasks while trying to call get_task_dict() that were members of more than 1 called group: {list(duplicate_tasks)}. Offending groups: {competing_groups}. Please call groups which overlap their constituent tasks in separate evaluation runs."
)
def get_task_dict( def get_task_dict(
task_name_list: Union[str, List[Union[str, Dict, Task]]], task_name_list: Union[str, List[Union[str, Dict, Task]]],
task_manager: Optional[TaskManager] = None, task_manager: Optional[TaskManager] = None,
...@@ -500,7 +417,6 @@ def get_task_dict( ...@@ -500,7 +417,6 @@ def get_task_dict(
:return :return
Dictionary of task objects Dictionary of task objects
""" """
task_name_from_string_dict = {} task_name_from_string_dict = {}
task_name_from_config_dict = {} task_name_from_config_dict = {}
task_name_from_object_dict = {} task_name_from_object_dict = {}
...@@ -547,16 +463,8 @@ def get_task_dict( ...@@ -547,16 +463,8 @@ def get_task_dict(
): ):
raise ValueError raise ValueError
final_task_dict = { return {
**task_name_from_string_dict, **task_name_from_string_dict,
**task_name_from_config_dict, **task_name_from_config_dict,
**task_name_from_object_dict, **task_name_from_object_dict,
} }
# behavior can get odd if one tries to invoke several groups that "compete" for the same task.
# (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
# and we'd be unsure which to use and report.)
# we explicitly check and error in this case.
_check_duplicates(get_subtask_list(final_task_dict))
return final_task_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment