Commit d4c5315a authored by Benjamin Fattori

sync working changes with upstream

parent 2da74953
......@@ -248,7 +248,7 @@ class Task(abc.ABC):
def doc_to_target(self, doc):
pass
def build_all_requests(self, limit=None):
def build_all_requests(self, limit=None, rank=None, world_size=None):
"""Build a set of Instances for a task, and store them in task.instances"""
if self.has_test_docs():
docs = self.test_docs()
......@@ -260,8 +260,9 @@ class Task(abc.ABC):
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
instances = []
for doc_id, doc in enumerate(itertools.islice(docs, 0, limit) if limit else docs):
# sample fewshot context
# for doc_id, doc in enumerate(itertools.islice(docs, 0, limit) if limit else docs):
for doc_id, doc in itertools.islice(enumerate(docs), rank, None, world_size):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random()
)
......
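The new rank/world_size arguments shard the documents round-robin across processes: each rank keeps every world_size-th doc, starting at its own rank. A minimal standalone sketch of that slicing, using a hypothetical shard_docs helper and toy documents rather than the harness's Task class:

import itertools

def shard_docs(docs, rank, world_size):
    # each rank keeps docs rank, rank + world_size, rank + 2 * world_size, ...
    # enumerating before slicing keeps doc_id global, which is what the TODO about
    # offsetting doc_id by rank is concerned with
    return list(itertools.islice(enumerate(docs), rank, None, world_size))

docs = [f"doc{i}" for i in range(7)]
print(shard_docs(docs, rank=0, world_size=2))  # [(0, 'doc0'), (2, 'doc2'), (4, 'doc4'), (6, 'doc6')]
print(shard_docs(docs, rank=1, world_size=2))  # [(1, 'doc1'), (3, 'doc3'), (5, 'doc5')]

Note that rank 0 always receives the first of any leftover documents, which is why the later padding code can assume rank 0 holds the largest shard.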
......@@ -7,7 +7,7 @@ import lm_eval.models
import lm_eval.tasks
import lm_eval.api
from lm_eval.utils import positional_deprecated, run_task_tests, make_table
import torch
@positional_deprecated
def simple_evaluate(
......@@ -79,19 +79,23 @@ def simple_evaluate(
decontamination_ngrams_path=decontamination_ngrams_path,
)
# add info about the model and few shot config
results["config"] = {
"model": model,
"model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"device": device,
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
if lm.rank == 0:
# add info about the model and few shot config
results["config"] = {
"model": model,
"model_args": model_args,
"num_fewshot": num_fewshot,
"batch_size": batch_size,
"device": device,
"no_cache": no_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
return results
else:
return None
return results
decontaminate_suffix = "_decontaminate"
......@@ -143,10 +147,20 @@ def evaluate(
# rnd.shuffle(task_docs)
# for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
task.build_all_requests(limit=limit)
task.build_all_requests(limit=limit, rank = lm.rank, world_size = lm.world_size)
# aggregate Instances by LM method requested to get output.
reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE #TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
if lm.world_size > 1:
instances_rnk = torch.tensor(len(task._instances), device = lm.device)
gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
# compute number of pseudobatches to pad with (FSDP/DDP require even batches + can't use join)
# we assume rank 0 always has largest iterator
numpad = gathered_item[0] - gathered_item[lm.rank]
if numpad > 0:
print(f"{task_name} / balancing iterators across ranks / rank: {lm.rank} / + {numpad} sample")
### Run LM on inputs, get all outputs ###
# execute each type of request
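The balancing step gathers each rank's instance count and computes how many pseudobatch entries every rank must add to match rank 0. A simplified sketch of that bookkeeping with plain lists standing in for the accelerator.gather of per-rank counts (padding_per_rank is a hypothetical helper):

def padding_per_rank(counts):
    # counts[i] is the number of instances rank i built; rank 0 is assumed to hold
    # the largest shard (true for the round-robin split above), so pad up to counts[0]
    return [counts[0] - c for c in counts]

print(padding_per_rank([4, 3]))  # [0, 1]: rank 1 must add one pseudobatch entry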
......@@ -157,6 +171,10 @@ def evaluate(
for req in reqs:
cloned_reqs.extend([req] * req.repeats)
if (lm.rank > 0) and (numpad > 0):
for _ in range(numpad):
cloned_reqs.extend([req] * req.repeats)
# run requests through model
resps = getattr(lm, reqtype)(cloned_reqs)
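Non-zero ranks pad their request list by repeating the last request numpad times, so that FSDP/DDP collectives see the same number of forward passes on every process. A hedged sketch of that pattern with a hypothetical pad_requests helper and string stand-ins for Instance objects:

def pad_requests(reqs, numpad):
    # repeat the final request so every rank feeds the model the same number of items;
    # the extra responses are redundant copies and carry no new information
    return (reqs + [reqs[-1]] * numpad) if reqs else reqs

print(pad_requests(["r0", "r1", "r2"], 1))  # ['r0', 'r1', 'r2', 'r2']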
......@@ -164,6 +182,9 @@ def evaluate(
for x, req in zip(resps, cloned_reqs):
req.resps.append(x)
if lm.world_size > 1:
lm.accelerator.wait_for_everyone()
### Postprocess outputs ###
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
for task_name, task in task_dict.items():
......@@ -187,25 +208,61 @@ def evaluate(
for metric, value in metrics.items():
vals[(task_name, key, metric)].append(value)
if lm.world_size > 1:
# if multigpu, then gather data across all ranks
vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items():
numitem = 0
if type(items[0]) == tuple:
numitem = len(items[0])
# distributed gather requires all ranks to have same dimensionality -> pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device = lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(metrics_tensor.to(torch.float32), pad_index = pad_value)
gathered_item = lm.accelerator.gather(torch_device_tensor)
#TODO: This is required when we get a tensor with a tuple of info like (ppl, _bytes) from wikitext
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:,0] != pad_value]
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
gathered_item = gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
# reconvert if we were passed a tuple of values
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item]
if lm.rank == 0:
vals_torch[(task_name, key, metric)] = gathered_item
vals = vals_torch
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + " - filter=" + key] = task.aggregation()[metric](items)
if lm.rank == 0:
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + " - filter=" + key] = task.aggregation()[metric](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them with fewer iterations. still looking for a cleaner way to do this
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them with fewer iterations. still looking for a cleaner way to do this
stderr = lm_eval.api.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
stderr = lm_eval.api.metrics.stderr_for_metric(
metric=task.aggregation()[metric],
bootstrap_iters=min(bootstrap_iters, 1000)
if metric in ["bleu", "chrf", "ter"]
else bootstrap_iters,
)
if stderr is not None:
results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(items)
if stderr is not None:
results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(items)
return {"results": dict(results), "versions": dict(versions)}
return {"results": dict(results), "versions": dict(versions)}
else:
return None
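Because the distributed gather requires identically shaped tensors on every rank, the new code pads each rank's metric tensor with float32's minimum value as a sentinel, gathers, and then strips the sentinel back out. A single-process sketch of that pad/gather/filter round trip, where torch.cat stands in for accelerator.gather and the metric values are toy numbers:

import torch

PAD = torch.finfo(torch.float32).min  # sentinel marking padded entries

def pad_to(t, length):
    # pad a 1-D metrics tensor with the sentinel so all ranks share one shape
    return torch.cat([t, torch.full((length - t.shape[0],), PAD)])

# simulate two ranks with uneven metric counts, "gather", then drop the padding
rank0 = torch.tensor([0.50, 0.75, 0.25])
rank1 = pad_to(torch.tensor([0.60, 0.40]), 3)
gathered = torch.cat([rank0, rank1])          # what accelerator.gather would return
values = gathered[gathered != PAD].tolist()   # [0.5, 0.75, 0.25, 0.6, 0.4]
print(values)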
......@@ -8,6 +8,8 @@ import torch.nn.functional as F
from lm_eval import utils
from lm_eval.api.model import LM, register_model
from accelerate import Accelerator
from itertools import islice
@register_model("hf-causal", "gpt2")
class HFLM(LM):
......@@ -26,20 +28,22 @@ class HFLM(LM):
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
print(f"Using device '{device}'")
else:
print("Device not specified")
print(f"Cuda Available? {torch.cuda.is_available()}")
self._device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("cpu")
)
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
......@@ -59,10 +63,17 @@ class HFLM(LM):
# multithreading and batching
self.batch_size_per_gpu = batch_size # todo: adaptive batch size
# TODO: fix multi-gpu
# gpus = torch.cuda.device_count()
# if gpus > 1:
# self.gpt2 = nn.DataParallel(self.gpt2)
if gpus > 1:
accelerator = Accelerator(device_placement=False)
self.gpt2 = accelerator.prepare(self.gpt2)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
print(f"Using {gpus} GPUs with FullyShardedDataParalell and accelerate")
self._rank = self.accelerator.local_process_index
self._world_size = gpus
@property
def eot_token_id(self):
......@@ -90,6 +101,14 @@ class HFLM(LM):
def device(self):
# TODO: fix multi-gpu
return self._device
@property
def rank(self):
return self._rank
@property
def world_size(self):
return self._world_size
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
......
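The multi-GPU path wraps the model with accelerate and then exposes the process index and world size as properties. A minimal sketch of that setup, assuming accelerate is installed and the script is started with accelerate launch; the Linear module is a stand-in for the HF model (self.gpt2), and everything else mirrors the calls used in the diff:

import torch
from accelerate import Accelerator

accelerator = Accelerator(device_placement=False)
model = torch.nn.Linear(8, 8)              # stand-in for the HF model (self.gpt2)
model = accelerator.prepare(model)         # wraps with DDP/FSDP per the launch config
rank = accelerator.local_process_index     # what the new lm.rank property exposes
world_size = accelerator.num_processes     # the commit uses torch.cuda.device_count() here
device = torch.device(f"cuda:{rank}") if torch.cuda.is_available() else torch.device("cpu")
if accelerator.is_local_main_process:
    print(f"rank {rank} / world size {world_size} / device {device}")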
......@@ -89,30 +89,31 @@ def main():
print(f"Selected Tasks: {task_names}")
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
with open(args.output_path, "w") as f:
f.write(dumped)
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
)
print(evaluator.make_table(results))
if results is not None:
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
device=args.device,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
)
dumped = json.dumps(results, indent=2)
print(dumped)
if args.output_path:
with open(args.output_path, "w") as f:
f.write(dumped)
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
)
print(evaluator.make_table(results))
if __name__ == "__main__":
......
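Since simple_evaluate now returns None on every rank except 0, the CLI only serializes and writes results on the main process. A hedged sketch of that consumer-side guard, with a hypothetical write_results helper and toy metric values:

import json

def write_results(results, output_path=None):
    # non-zero ranks receive None from simple_evaluate and must skip all I/O
    if results is None:
        return
    dumped = json.dumps(results, indent=2)
    print(dumped)
    if output_path:
        with open(output_path, "w") as f:
            f.write(dumped)

write_results({"results": {"lambada": {"acc": 0.72}}})  # rank 0 prints and writes
write_results(None)                                     # any other rank: no-op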