Unverified commit 9a877197, authored by Stella Biderman and committed by GitHub

Merge pull request #394 from fattorib/auto-batching

single GPU automatic batching logic
parents fc4428dc d6ceced5
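For context on the mechanism the whole change leans on: accelerate's `find_executable_batch_size` decorates a function whose first argument is a batch size, runs it at `starting_batch_size`, and on a CUDA out-of-memory error halves the batch size and retries until the body completes (or the size reaches zero). Below is a minimal sketch of the probing pattern used throughout this diff, with `model_call` and `seq_len` as hypothetical stand-ins for the harness's `self._model_call` and `self.max_length`; only `find_executable_batch_size` itself is accelerate API.

```python
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size


def detect_batch_size(model_call, seq_len, device="cuda", starting_batch_size=512):
    """Sketch: largest batch size that survives a few dummy forward passes."""

    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def forward_batch(batch_size):
        # accelerate catches CUDA OOM raised in this body, halves batch_size,
        # and calls the function again until it succeeds (or reaches zero).
        test_batch = torch.ones((batch_size, seq_len), device=device).long()
        for _ in range(5):
            F.log_softmax(model_call(test_batch), dim=-1).cpu()
        return batch_size

    return forward_batch()
```

Probing with the worst-case sequence length means the detected size should also hold for the real, generally shorter, requests that follow.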
@@ -11,6 +11,7 @@ from sqlitedict import SqliteDict
 from tqdm import tqdm
 import torch
 import torch.nn.functional as F
+from accelerate import find_executable_batch_size
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils
@@ -186,7 +187,22 @@ class BaseLM(LM):
     def loglikelihood_rolling(self, requests):
         # TODO: Implement caching once we've confirmed the perplexity implementation
-        # TODO: automatic batch size detection for vectorization
+        # automatic batch size detection for vectorization
+        adaptive_batch_size = None
+        if self.batch_size == 'auto':
+            # using rolling window with maximum context
+            print('Passed argument batch_size = auto. Detecting largest batch size')
+            @find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
+            def forward_batch(batch_size):
+                test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
+                for _ in range(5):
+                    out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
+                return batch_size
+            batch_size = forward_batch()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
         loglikelihoods = []
         for (string,) in tqdm(requests):
@@ -207,7 +223,7 @@ class BaseLM(LM):
             # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
             # that
             string_nll = self._loglikelihood_tokens(
-                rolling_token_windows, disable_tqdm=True
+                rolling_token_windows, disable_tqdm=True, override_bs = adaptive_batch_size
             )
             # discard is_greedy
@@ -218,7 +234,7 @@ class BaseLM(LM):
         return loglikelihoods
-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs = None):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
@@ -233,10 +249,33 @@
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)
-        # TODO: automatic (variable) batch size detection for vectorization
         re_ord = utils.Reorderer(requests, _collate)
+        # automatic (variable) batch size detection for vectorization
+        # pull longest context sample from request
+        _, context_enc, continuation_enc = re_ord.get_reordered()[0]
+        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+        if (self.batch_size == 'auto'):
+            if override_bs is None:
+                print('Passed argument batch_size = auto. Detecting largest batch size')
+                @find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
+                def forward_batch(batch_size):
+                    test_batch = torch.ones((batch_size, max_context), device=self.device).long()
+                    for _ in range(5):
+                        out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
+                    return batch_size
+                batch_size = forward_batch()
+                print(f"Determined largest batch size: {batch_size}")
+                adaptive_batch_size = batch_size
+            else:
+                adaptive_batch_size = override_bs
         for chunk in utils.chunks(
-            tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size
+            tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size if self.batch_size != "auto" else adaptive_batch_size
         ):
             inps = []
             cont_toks_list = []
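A note on why a single probe shape suffices in `_loglikelihood_tokens`: the `_collate` key above sorts requests longest-first (via `utils.Reorderer`), so `re_ord.get_reordered()[0]` is the worst case, and a batch size that fits `(batch_size, max_context)` should fit every later chunk. A toy illustration of that length computation, using plain 2-tuples and 2048 in place of the harness's request tuples and `self.max_length`:

```python
# Toy requests: (context_tokens, continuation_tokens) pairs.
requests = [([1, 2], [3]), ([1], [2, 3, 4, 5]), ([9, 8, 7], [6, 5, 4, 3])]


def collate_key(x):
    toks = x[0] + x[1]
    return (-len(toks), tuple(toks))  # longest first, ties broken by content


reordered = sorted(requests, key=collate_key)
context_enc, continuation_enc = reordered[0]  # the longest request
max_context = len((context_enc + continuation_enc)[-(2048 + 1):][:-1])
print(max_context)  # 6: total length minus the final target token, capped at the window
```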
@@ -3,7 +3,6 @@ import transformers
 from typing import Optional
 from lm_eval.base import BaseLM
 class HFLM(BaseLM):
     def __init__(
         self,
@@ -21,7 +20,7 @@ class HFLM(BaseLM):
         assert isinstance(device, str)
         assert isinstance(pretrained, str)
-        assert isinstance(batch_size, int)
+        assert isinstance(batch_size, (int,str))
         device_list = set(["cuda", "cpu"] + [f'cuda:{i}' for i in range(torch.cuda.device_count())])
         if device and device in device_list:
@@ -56,13 +55,21 @@ class HFLM(BaseLM):
         self.vocab_size = self.tokenizer.vocab_size
         # multithreading and batching
-        self.batch_size_per_gpu = batch_size # todo: adaptive batch size
         # TODO: fix multi-gpu
         # gpus = torch.cuda.device_count()
         # if gpus > 1:
         #     self.gpt2 = nn.DataParallel(self.gpt2)
         if isinstance(
             self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
         ):
             assert self.tokenizer.encode("hello\n\nhello") == [
                 31373,
                 198,
                 198,
                 31373,
             ], self.tokenizer.encode("hello\n\nhello")
+        # setup for automatic batch size detection
+        if batch_size == 'auto':
+            self.batch_size_per_gpu = batch_size
+        else:
+            self.batch_size_per_gpu = int(batch_size)
     @property
     def eot_token_id(self):
@@ -7,6 +7,7 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
 from tqdm import tqdm
 from transformers import BatchEncoding
+from accelerate import find_executable_batch_size
 from lm_eval import utils
 from lm_eval.base import BaseLM
@@ -71,7 +72,7 @@ class HuggingFaceAutoLM(BaseLM):
         tokenizer: Optional[str] = None,
         subfolder: Optional[str] = None,
         revision: Optional[str] = "main",
-        batch_size: Optional[int] = 1,
+        batch_size: Optional[Union[int,str]] = 1,
         max_gen_toks: Optional[int] = 256,
         max_length: Optional[int] = None,
         add_special_tokens: Optional[bool] = None,
@@ -143,7 +144,7 @@ class HuggingFaceAutoLM(BaseLM):
         assert isinstance(pretrained, str)
         assert isinstance(device, str)
-        assert isinstance(batch_size, int)
+        assert isinstance(batch_size, (int, str))
         if (
             add_special_tokens is not None
             and self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM
@@ -157,7 +158,12 @@ class HuggingFaceAutoLM(BaseLM):
                 not add_special_tokens
             ), "Evaluating causal models with `add_special_tokens=True` is currently not supported."
-        self._batch_size = batch_size # TODO: Adaptive batch size
+        # setup for automatic batch size detection
+        if batch_size == 'auto':
+            self._batch_size = batch_size
+        else:
+            self._batch_size = int(batch_size)
         self._max_gen_toks = max_gen_toks
         self._max_length = max_length
         self._config = self.AUTO_CONFIG_CLASS.from_pretrained(
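In both model classes the constructor keeps the literal string 'auto' as a sentinel rather than resolving it immediately: the actual probe runs later, once a batch of requests (and hence the worst-case sequence shape) is in hand, and any other value is cast with `int()` because `main.py` now hands `--batch_size` over as a string.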
@@ -366,10 +372,30 @@
             tokens = self.tok_encode(x[0])
             return len(tokens), x[0]
         results = []
         reorder = utils.Reorderer(requests, _collate)
+        _, context_enc, continuation_enc = reorder.get_reordered()[0]
+        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+        adaptive_batch_size = None
+        if self.batch_size == 'auto':
+            # using rolling window with maximum context
+            print('Passed argument batch_size = auto. Detecting largest batch size')
+            @find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
+            def forward_batch(batch_size):
+                test_batch = torch.ones((batch_size, max_context), device=self.device).long()
+                for _ in range(5):
+                    out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
+                return batch_size
+            batch_size = forward_batch()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
         for chunk in utils.chunks(
-            tqdm(reorder.get_reordered(), disable=False), self.batch_size
+            tqdm(reorder.get_reordered(), disable=False), self.batch_size if self.batch_size != "auto" else adaptive_batch_size
         ):
             context = [c[0] for c in chunk]
             request_args = chunk[0][1]
@@ -32,7 +32,7 @@ def parse_args():
     parser.add_argument("--tasks", default=None, choices=MultiChoice(tasks.ALL_TASKS))
     parser.add_argument("--provide_description", action="store_true")
     parser.add_argument("--num_fewshot", type=int, default=0)
-    parser.add_argument("--batch_size", type=int, default=None)
+    parser.add_argument("--batch_size", type=str, default=None)
     parser.add_argument("--device", type=str, default=None)
     parser.add_argument("--output_path", default=None)
     parser.add_argument("--limit", type=int, default=None)
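With `--batch_size` parsed as a string, automatic detection is requested from the CLI by passing the literal value, e.g. `python main.py ... --batch_size auto --device cuda:0` (model and task arguments elided); an explicit integer such as `--batch_size 8` should behave as before, since the model classes above cast non-'auto' values with `int()`.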
@@ -38,6 +38,7 @@ setuptools.setup(
         "tqdm-multiprocess",
         "transformers>=4.1",
         "zstandard",
+        "accelerate>=0.17.1"
     ],
     extras_require={
         "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],