Unverified commit ca36c41c authored by Lintang Sutawika, committed by GitHub

Merge pull request #673 from fattorib/big-refactor-autobatching

[Refactor] Port over Autobatching
parents 1ef21c7b e0281126
......@@ -82,6 +82,17 @@ python main.py \
Models that are loaded via either `transformers.AutoModelForCausalLM` (autoregressive, decoder-only GPT style models) or `transformers.AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface are supported, though support for the latter model type is currently pending.
Batch size selection can be automated by setting the ```--batch_size``` flag to ```auto```. This will perform automatic detection of the largest batch size that will fit on your device. On tasks where there is a large difference between the longest and shortest examples, it can be helpful to periodically recompute the largest batch size, to gain a further speedup. To do this, append ```:N``` to the above flag to automatically recompute the largest batch size ```N``` times. For example, to recompute the batch size 4 times, the command would be:
```bash
python main.py \
--model hf \
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
--tasks lambada_openai,hellaswag \
--device cuda:0 \
--batch_size auto:4
```
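The same auto-batching arguments can also be passed when constructing the Hugging Face model wrapper directly in Python. A minimal sketch, assuming the `HFLM` class from this refactor is importable as `lm_eval.models.huggingface.HFLM` (the import path is an assumption, not taken from this diff):

```python
# Illustrative only: passing the new auto-batching arguments to the HF wrapper.
from lm_eval.models.huggingface import HFLM  # assumed import path

lm = HFLM(
    pretrained="EleutherAI/pythia-160m",
    batch_size="auto:4",  # detect the largest batch size, recomputing it 4 times
    max_batch_size=64,    # upper bound / starting point for the detection search
)
```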
### Multi-GPU Evaluation with Hugging Face `accelerate`
To parallelize evaluation of HuggingFace models across multiple GPUs, we support two different types of multi-GPU evaluation.
......
......@@ -20,7 +20,7 @@ from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from accelerate import Accelerator, find_executable_batch_size
from typing import List, Optional, Union
......@@ -70,7 +70,8 @@ class HFLM(LM):
        max_length: Optional[int] = None,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        batch_size: Optional[int] = 1,
        batch_size: Optional[Union[int, str]] = 1,
        max_batch_size: Optional[int] = 64,
        low_cpu_mem_usage: Optional[bool] = True,
        trust_remote_code: Optional[bool] = False,
        use_fast_tokenizer: Optional[bool] = True,
......@@ -94,7 +95,7 @@ class HFLM(LM):
        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
        accelerator = Accelerator()
......@@ -244,8 +245,16 @@ class HFLM(LM):
        self._max_length = max_length

        # multithreading and batching
        self.batch_size_per_gpu = batch_size
        self.batch_schedule = 1
        self.batch_sizes = {}
        self.max_batch_size = max_batch_size

        if str(batch_size).startswith("auto"):
            batch_size = batch_size.split(":")
            self.batch_size_per_gpu = batch_size[0]
            self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)

        # multigpu data-parallel support when launched with accelerate
        if gpus > 1:
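To make the accepted values concrete, the parsing above behaves roughly like the standalone helper below: an integer is used as-is, `auto` enables detection with a single computation, and `auto:N` enables detection recomputed `N` times. The helper name is hypothetical and does not exist in the codebase.

```python
# Hypothetical helper restating the constructor's batch_size parsing above.
def parse_batch_size(batch_size):
    """Return (per_gpu_batch_size, batch_schedule) for int, "auto", or "auto:N"."""
    if str(batch_size).startswith("auto"):
        parts = str(batch_size).split(":")
        schedule = float(parts[1]) if len(parts) > 1 else 1
        return "auto", schedule
    return int(batch_size), 1


assert parse_batch_size(8) == (8, 1)
assert parse_batch_size("auto") == ("auto", 1)
assert parse_batch_size("auto:4") == ("auto", 4.0)
```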
......@@ -342,6 +351,52 @@ class HFLM(LM):
    def world_size(self):
        return self._world_size

    def _detect_batch_size(self, requests=None, pos=0):
        if requests:
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        else:
            max_length = self.max_length

        # if OOM, then halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones((batch_size, length), device=self.device).long()
                test_batch = torch.ones((batch_size, length), device=self.device).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones((batch_size, max_length), device=self.device).long()
            for _ in range(5):
                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)
            return batch_size

        batch_size = forward_batch()

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
            utils.clear_torch_cache()
            return batch_size

        utils.clear_torch_cache()
        return batch_size
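`find_executable_batch_size` is the `accelerate` decorator used above: it calls the wrapped function with a starting batch size and, on an out-of-memory error, clears cached memory and retries with half the batch size until a run succeeds. Below is a simplified sketch of that behavior, under the assumption that OOM surfaces as a `RuntimeError`; the real implementation lives in `accelerate`, not here.

```python
import functools
import gc

import torch


def find_executable_batch_size_sketch(function=None, starting_batch_size=64):
    """Simplified, illustrative stand-in for accelerate.find_executable_batch_size."""
    if function is None:
        return functools.partial(
            find_executable_batch_size_sketch, starting_batch_size=starting_batch_size
        )

    @functools.wraps(function)
    def wrapper():
        batch_size = starting_batch_size
        while batch_size > 0:
            try:
                return function(batch_size)
            except RuntimeError as exc:
                if "out of memory" not in str(exc).lower():
                    raise
                # OOM: free cached memory, halve the batch size, and retry.
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                batch_size //= 2
        raise RuntimeError("No executable batch size found; reached a batch size of 0.")

    return wrapper
```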
    def tok_encode(self, string: str, left_truncate_len=None):
        """ """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
......@@ -480,6 +535,15 @@ class HFLM(LM):
    def loglikelihood_rolling(self, requests):
        loglikelihoods = []

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
......@@ -509,7 +573,7 @@ class HFLM(LM):
                rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
                rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size
            )

            if (self.world_size > 1) and (pad_amnt > 0):
......@@ -523,7 +587,7 @@ class HFLM(LM):
        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []
......@@ -537,12 +601,37 @@ class HFLM(LM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        n_reordered_requests = len(re_ord.get_reordered())

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        def _batch_scheduler(pos):
            sched = pos // int(n_reordered_requests / self.batch_schedule)
            if sched in self.batch_sizes:
                return self.batch_sizes[sched]
            if (len(self.batch_sizes) > 1) and (
                self.batch_sizes[sched - 1] == self.max_batch_size
            ):
                # if previous batch size is already maximal, skip recomputation
                self.batch_sizes[sched] = self.max_batch_size
                return self.batch_sizes[sched]
            print(
                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
            )
            self.batch_sizes[sched] = self._detect_batch_size(re_ord.get_reordered(), pos)
            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
            return self.batch_sizes[sched]

        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
            n=self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0,
            fn=_batch_scheduler
            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
            else None,
        ):
            inps = []
            cont_toks_list = []
......
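As a worked example of the `auto:N` schedule used in `_batch_scheduler` above: with `batch_schedule = 4` and 1000 reordered requests, positions 0-249 share one detected batch size, positions 250-499 the next, and so on; since requests are sorted longest-first, later buckets usually admit larger batches. The helper below is hypothetical and only restates that arithmetic.

```python
# Hypothetical restatement of the bucket arithmetic in _batch_scheduler.
def schedule_bucket(pos: int, n_requests: int, batch_schedule: int) -> int:
    return pos // int(n_requests / batch_schedule)


if __name__ == "__main__":
    n_requests, batch_schedule = 1000, 4
    for pos in (0, 249, 250, 750, 999):
        print(pos, "->", schedule_bucket(pos, n_requests, batch_schedule))
    # prints: 0 -> 0, 249 -> 0, 250 -> 1, 750 -> 3, 999 -> 3
```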
......@@ -32,7 +32,7 @@ def parse_args():
        default=None,
        help="Number of examples in few-shot context",
    )
    parser.add_argument("--batch_size", type=int, default=1)  # TODO: only integers
    parser.add_argument("--batch_size", type=str, default=1)
    parser.add_argument(
        "--max_batch_size",
        type=int,
......