"mmdet3d/models/vscode:/vscode.git/clone" did not exist on "ff1e5b4ef45bbcf48d253e08510da31fd492a228"
Unverified Commit f862a118, authored by Stella Biderman, committed by GitHub

Merge pull request #572 from gakada/perf

Add --max_batch_size and --batch_size auto:N
parents 23dcc12e 8cec82b2
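
For context, a minimal usage sketch of the new options from the Python API, assuming the `lm_eval` package layout in this repository; the model name, model_args and task below are illustrative placeholders rather than part of this change:

from lm_eval import evaluator

# "auto:4" requests automatic batch size detection, re-run 4 times over the
# sorted requests; max_batch_size caps the size that detection will try.
results = evaluator.simple_evaluate(
    model="hf-causal-experimental",   # placeholder: registry name assumed here
    model_args="pretrained=gpt2",     # placeholder model arguments
    tasks=["lambada_openai"],         # placeholder task
    batch_size="auto:4",
    max_batch_size=256,
)
print(results["config"]["batch_sizes"])  # batch sizes detected during the run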
@@ -119,6 +119,12 @@ class LM(abc.ABC):
 class BaseLM(LM):
+    def __init__(self):
+        super().__init__()
+        self.batch_schedule = 1
+        self.batch_sizes = {}
+        self.max_batch_size = 512
+
     @property
     @abstractmethod
     def eot_token_id(self):
@@ -167,6 +173,26 @@ class BaseLM(LM):
         """
         pass
 
+    def _detect_batch_size(self, requests=None, pos=0):
+        if requests:
+            _, context_enc, continuation_enc = requests[pos]
+            max_length = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+        else:
+            max_length = self.max_length
+
+        # if OOM, then halves batch_size and tries again
+        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
+        def forward_batch(batch_size):
+            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
+            for _ in range(5):
+                _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
+            return batch_size
+
+        batch_size = forward_batch()
+        utils.clear_torch_cache()
+        return batch_size
+
     # subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
     # TODO: enforce this somehow
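
The `find_executable_batch_size` decorator used by `_detect_batch_size` comes from Hugging Face `accelerate`: it retries the wrapped function, halving `batch_size` whenever the forward pass runs out of CUDA memory. A rough stand-alone sketch of that retry pattern, for illustration only (the helper name below is made up and this is not accelerate's actual implementation):

import torch

def halving_batch_size_search(fn, starting_batch_size=512):
    # Try fn(batch_size), halving on CUDA OOM until something fits.
    batch_size = starting_batch_size
    while batch_size >= 1:
        try:
            return fn(batch_size)
        except RuntimeError as e:
            # Assumption: CUDA OOM surfaces as a RuntimeError mentioning "out of memory".
            if "out of memory" not in str(e).lower():
                raise
            torch.cuda.empty_cache()
            batch_size //= 2
    raise RuntimeError("no executable batch size found")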
@@ -202,19 +228,7 @@ class BaseLM(LM):
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
-
-            @find_executable_batch_size(
-                starting_batch_size=512
-            )  # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones(
-                    (batch_size, self.max_length), device=self.device
-                ).long()
-                for _ in range(5):
-                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                return batch_size
-
-            batch_size = forward_batch()
+            batch_size = self._detect_batch_size()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
@@ -267,34 +281,24 @@ class BaseLM(LM):
         re_ord = utils.Reorderer(requests, _collate)
+        reordered_requests = re_ord.get_reordered()
+        n_reordered_requests = len(reordered_requests)
 
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        if len(re_ord.get_reordered()) > 0:
-            _, context_enc, continuation_enc = re_ord.get_reordered()[0]
-            max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
-            if (self.batch_size == 'auto'):
-                if override_bs is None:
-                    print('Passed argument batch_size = auto. Detecting largest batch size')
-                    @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
-                    def forward_batch(batch_size):
-                        test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                        for _ in range(5):
-                            out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
-                        return batch_size
-                    batch_size = forward_batch()
-                    print(f"Determined largest batch size: {batch_size}")
-                    adaptive_batch_size = batch_size
-                else:
-                    adaptive_batch_size = override_bs
-        else:
-            adaptive_batch_size = 0 if override_bs is None else override_bs
+        def _batch_scheduler(pos):
+            sched = pos // int(n_reordered_requests / self.batch_schedule)
+            if sched in self.batch_sizes:
+                return self.batch_sizes[sched]
+            print(f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size")
+            self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
+            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+            return self.batch_sizes[sched]
 
         for chunk in utils.chunks(
-            tqdm(re_ord.get_reordered(), disable=disable_tqdm),
-            self.batch_size if self.batch_size != "auto" else adaptive_batch_size,
+            tqdm(reordered_requests, disable=disable_tqdm),
+            n=self.batch_size if self.batch_size != "auto" else override_bs if override_bs is not None else 0,
+            fn=_batch_scheduler if self.batch_size == "auto" and n_reordered_requests > 0 else None,
         ):
             inps = []
             cont_toks_list = []
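
To make the scheduling concrete, a small worked example of the `sched` arithmetic in `_batch_scheduler`, assuming 1000 reordered requests and `--batch_size auto:4` (so `batch_schedule == 4`); the numbers are illustrative:

n_reordered_requests = 1000
batch_schedule = 4

for pos in (0, 249, 250, 749, 750, 999):
    sched = pos // int(n_reordered_requests / batch_schedule)
    print(pos, sched)
# Positions 0-249 fall in schedule bucket 0, 250-499 in bucket 1, and so on,
# so _detect_batch_size runs at most 4 times for this request list.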
......
@@ -16,6 +16,7 @@ def simple_evaluate(
     tasks=[],
     num_fewshot=0,
     batch_size=None,
+    max_batch_size=None,
     device=None,
     no_cache=False,
     limit=None,
@@ -37,8 +38,10 @@ def simple_evaluate(
         List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
     :param num_fewshot: int
         Number of examples in few-shot context
-    :param batch_size: int, optional
+    :param batch_size: int or str, optional
         Batch size for model
+    :param max_batch_size: int, optional
+        Maximal batch size to try with automatic batch size detection
     :param device: str, optional
         PyTorch device (e.g. "cpu" or "cuda:0") for running models
     :param no_cache: bool
@@ -67,7 +70,7 @@ def simple_evaluate(
         if model_args is None:
             model_args = ""
         lm = lm_eval.models.get_model(model).create_from_arg_string(
-            model_args, {"batch_size": batch_size, "device": device}
+            model_args, {"batch_size": batch_size, "max_batch_size": max_batch_size, "device": device}
         )
     else:
         assert isinstance(model, lm_eval.base.LM)
@@ -106,6 +109,7 @@ def simple_evaluate(
         "model_args": model_args,
         "num_fewshot": num_fewshot,
         "batch_size": batch_size,
+        "batch_sizes": list(lm.batch_sizes.values()),
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
......
@@ -9,7 +9,6 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
 from tqdm import tqdm
 from transformers import BatchEncoding
-from accelerate import find_executable_batch_size
 
 from lm_eval import utils
 from lm_eval.base import BaseLM
@@ -76,6 +75,7 @@ class HuggingFaceAutoLM(BaseLM):
         subfolder: Optional[str] = None,
         revision: Optional[str] = "main",
         batch_size: Optional[Union[int, str]] = 1,
+        max_batch_size: Optional[int] = 512,
         max_gen_toks: Optional[int] = 256,
         max_length: Optional[int] = None,
         add_special_tokens: Optional[bool] = None,
@@ -172,10 +172,13 @@ class HuggingFaceAutoLM(BaseLM):
         ), "Evaluating causal models with `add_special_tokens=True` is currently not supported."
 
         # setup for automatic batch size detection
-        if batch_size == "auto":
-            self._batch_size = batch_size
+        if str(batch_size).startswith("auto"):
+            batch_size = batch_size.split(":")
+            self._batch_size = batch_size[0]
+            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
         else:
             self._batch_size = int(batch_size)
+        self.max_batch_size = max_batch_size
 
         self._max_gen_toks = max_gen_toks
         self._max_length = max_length
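
A short sketch of what the constructor change does with the new syntax; the helper below is hypothetical and only mirrors the branch above for illustration:

def parse_batch_size(batch_size):
    # "auto" or "auto:N" -> ("auto", N as float); anything else -> fixed int size.
    if str(batch_size).startswith("auto"):
        parts = str(batch_size).split(":")
        return parts[0], float(parts[1]) if len(parts) > 1 else 1
    return int(batch_size), 1

print(parse_batch_size(8))         # (8, 1)
print(parse_batch_size("auto"))    # ('auto', 1)
print(parse_batch_size("auto:4"))  # ('auto', 4.0)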
@@ -411,19 +414,7 @@ class HuggingFaceAutoLM(BaseLM):
         if self.batch_size == "auto":
             # using rolling window with maximum context
             print("Passed argument batch_size = auto. Detecting largest batch size")
-
-            @find_executable_batch_size(
-                starting_batch_size=512
-            )  # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones(
-                    (batch_size, self.max_length), device=self.device
-                ).long()
-                for _ in range(5):
-                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
-                return batch_size
-
-            batch_size = forward_batch()
+            batch_size = self._detect_batch_size()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
......
@@ -8,6 +8,7 @@ import sys
 import fnmatch
 from typing import List, Union
+import gc
 import torch
 from omegaconf import OmegaConf
@@ -64,11 +65,11 @@ def join_iters(iters):
         yield from iter
 
-def chunks(iter, n):
+def chunks(iter, n=0, fn=None):
     arr = []
-    for x in iter:
+    for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == n:
+        if len(arr) == (fn(i) if fn else n):
             yield arr
             arr = []
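
A small usage sketch of the extended `chunks()` helper, assuming the existing behaviour of yielding any trailing partial chunk; the data and callable below are illustrative:

from lm_eval import utils

data = list(range(10))

# Fixed-size chunks, as before.
print(list(utils.chunks(data, n=3)))

# Variable-size chunks: fn(i) decides the boundary from the current index,
# which is how _batch_scheduler feeds per-segment batch sizes into chunks().
print(list(utils.chunks(data, fn=lambda i: 2 if i < 5 else 4)))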
@@ -283,3 +284,8 @@ def run_task_tests(task_list: List[str]):
         raise ValueError(
             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
         )
+
+
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
@@ -16,6 +16,8 @@ def parse_args():
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
     parser.add_argument("--batch_size", type=str, default=None)
+    parser.add_argument("--max_batch_size", type=int, default=None,
+                        help="Maximal batch size to try with --batch_size auto")
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--output_path", default=None)
     parser.add_argument("--limit", type=float, default=None,
@@ -60,6 +62,7 @@ def main():
         tasks=task_names,
         num_fewshot=args.num_fewshot,
         batch_size=args.batch_size,
+        max_batch_size=args.max_batch_size,
         device=args.device,
         no_cache=args.no_cache,
         limit=args.limit,
@@ -78,9 +81,10 @@ def main():
         with open(args.output_path, "w") as f:
             f.write(dumped)
 
+    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
     print(
         f"{args.model} ({args.model_args}), limit: {args.limit}, provide_description: {args.provide_description}, "
-        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}"
+        f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
     )
     print(evaluator.make_table(results))
......