Commit eddd627a authored by Benjamin Fattori's avatar Benjamin Fattori
Browse files

add utility func for slicing iterators, only tqdm on main process

parent 629bcfba
...@@ -260,8 +260,7 @@ class Task(abc.ABC): ...@@ -260,8 +260,7 @@ class Task(abc.ABC):
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!" ), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
instances = [] instances = []
# for doc_id, doc in enumerate(itertools.islice(docs, 0, limit) if limit else docs): for doc_id, doc in utils.create_iterator(enumerate(docs), rank, world_size, limit):
for doc_id, doc in itertools.islice(enumerate(docs), rank, None, world_size):
# sample fewshot context #TODO: need to offset doc_id by rank now! # sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context( fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random() doc, self._config.num_fewshot, rnd=random.Random()
......
...@@ -6,7 +6,7 @@ import lm_eval.api.metrics ...@@ -6,7 +6,7 @@ import lm_eval.api.metrics
import lm_eval.models import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.api import lm_eval.api
from lm_eval.utils import positional_deprecated, run_task_tests, make_table from lm_eval.utils import positional_deprecated, run_task_tests, make_table, create_iterator
import torch import torch
@positional_deprecated @positional_deprecated
...@@ -146,7 +146,6 @@ def evaluate( ...@@ -146,7 +146,6 @@ def evaluate(
# rnd.seed(42) # rnd.seed(42)
# rnd.shuffle(task_docs) # rnd.shuffle(task_docs)
# for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
task.build_all_requests(limit=limit, rank = lm.rank, world_size = lm.world_size) task.build_all_requests(limit=limit, rank = lm.rank, world_size = lm.world_size)
# aggregate Instances by LM method requested to get output. # aggregate Instances by LM method requested to get output.
reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE #TODO: this is hacky, fix in task.py reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE #TODO: this is hacky, fix in task.py
...@@ -156,11 +155,9 @@ def evaluate( ...@@ -156,11 +155,9 @@ def evaluate(
instances_rnk = torch.tensor(len(task._instances), device = lm.device) instances_rnk = torch.tensor(len(task._instances), device = lm.device)
gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist() gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
# compute number of pseudobatches to pad with (FSDP/DDP require even batches + can't use join) # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
# we assume rank 0 always has largest iterator # we assume rank 0 always has largest iterator
numpad = gathered_item[0] - gathered_item[lm.rank] numpad = gathered_item[0] - gathered_item[lm.rank]
if numpad > 0:
print(f"{task_name} / balancing iterators across ranks / rank: {lm.rank} / + {numpad} sample")
### Run LM on inputs, get all outputs ### ### Run LM on inputs, get all outputs ###
# execute each type of request # execute each type of request
...@@ -200,7 +197,8 @@ def evaluate( ...@@ -200,7 +197,8 @@ def evaluate(
# calculate values for each filter setup (TODO: make getting list of keys cleaner) # calculate values for each filter setup (TODO: make getting list of keys cleaner)
# TODO: make it possible to use a different metric per key # TODO: make it possible to use a different metric per key
for key in task.instances[0].filtered_resps.keys(): for key in task.instances[0].filtered_resps.keys():
for doc_id, doc in itertools.islice(enumerate(task.test_docs()), lm.rank, None, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, None, lm.world_size): doc_iterator = itertools.islice(enumerate(task.test_docs()), lm.rank, None, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, None, lm.world_size)
for doc_id, doc in doc_iterator:
# subset instances to only this document id ; sort by idx # subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances)) requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
requests.sort(key=lambda x: x.idx) requests.sort(key=lambda x: x.idx)
......
...@@ -73,7 +73,7 @@ class HFLM(LM): ...@@ -73,7 +73,7 @@ class HFLM(LM):
self.accelerator = accelerator self.accelerator = accelerator
if self.accelerator.is_local_main_process: if self.accelerator.is_local_main_process:
print(f"Using {gpus} GPUs with FullyShardedDataParalell and accelerate") print(f"Using {gpus} GPUs with Data Parallelism")
self._rank = self.accelerator.local_process_index self._rank = self.accelerator.local_process_index
self._world_size = gpus self._world_size = gpus
...@@ -202,7 +202,7 @@ class HFLM(LM): ...@@ -202,7 +202,7 @@ class HFLM(LM):
# TODO: automatic (variable) batch size detection for vectorization # TODO: automatic (variable) batch size detection for vectorization
re_ord = utils.Reorderer(requests, _collate) re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks( for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size tqdm(re_ord.get_reordered(), disable=(disable_tqdm or not (self.rank == 0))), self.batch_size
): ):
inps = [] inps = []
cont_toks_list = [] cont_toks_list = []
......
...@@ -9,7 +9,7 @@ from typing import List ...@@ -9,7 +9,7 @@ from typing import List
from omegaconf import OmegaConf from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice
class ExitCodeError(Exception): class ExitCodeError(Exception):
pass pass
...@@ -246,3 +246,12 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined) ...@@ -246,3 +246,12 @@ env = Environment(loader=BaseLoader, undefined=StrictUndefined)
def apply_template(template, doc): def apply_template(template, doc):
rtemplate = env.from_string(template) rtemplate = env.from_string(template)
return rtemplate.render(**doc) return rtemplate.render(**doc)
def create_iterator(raw_iterator, rank, world_size, limit=None):
    """Return a sliced (and optionally limited) view of ``raw_iterator``.

    Used to split documents among ranks in a multi-GPU setting (each rank
    takes every ``world_size``-th item starting at its own ``rank`` offset)
    or to pull only a sample of documents.

    Args:
        raw_iterator: The underlying iterable of documents.
        rank: This process's rank; used as the slice start offset.
        world_size: Total number of processes; used as the slice step.
        limit: Stop index into the *raw* iterator (not a per-rank count).
            With ``limit`` set, each rank receives roughly
            ``limit / world_size`` items. ``None`` means no limit.

    Returns:
        An iterator equivalent to ``raw_iterator[rank:limit:world_size]``.
    """
    return islice(raw_iterator, rank, limit, world_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment