Unverified commit 992f021b, authored by Lintang Sutawika, committed by GitHub

Merge pull request #912 from AndyWolfZwei/andy/big-refactor-auto-batch

[Refactor] Add _batch_scheduler in greedy_until
parents 2afe7770 660dfb71
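
This refactor hoists the _batch_scheduler closure out of _loglikelihood_tokens into a method on HFLM so that greedy_until can reuse the same auto batch sizing logic. Because a method cannot close over the caller's locals, it now takes the reordered request list explicitly, and utils.chunks passes its iterable to the callback. A minimal sketch of the resulting shape, using a hypothetical FakeLM class with stub constants and a stubbed-out _detect_batch_size (the real detection probes for the largest batch that fits in memory):

# Minimal sketch, not the real HFLM: FakeLM, its constants, and the stub
# _detect_batch_size are hypothetical; _batch_scheduler mirrors the
# arithmetic of the method added in this commit.
class FakeLM:
    batch_schedule = 4   # number of re-detection segments ("auto:4")
    max_batch_size = 64

    def __init__(self):
        self.batch_sizes = {}  # segment index -> detected batch size

    def _detect_batch_size(self, requests, pos):
        return 8  # stand-in for the real OOM-probing detection

    def _batch_scheduler(self, pos, reordered_requests):
        # same segment arithmetic as the real method
        sched = pos // (len(reordered_requests) // self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        self.batch_sizes[sched] = self._detect_batch_size(reordered_requests, pos)
        return self.batch_sizes[sched]

lm = FakeLM()
requests = list(range(100))
assert lm._batch_scheduler(0, requests) == 8    # detected for segment 0
assert lm._batch_scheduler(10, requests) == 8   # cached, no re-detection

Both _loglikelihood_tokens and greedy_until can now pass the same bound method as the fn= callback to utils.chunks.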
@@ -43,7 +43,7 @@ jobs:
   # # mypy turned off for now
   # - name: Lint with mypy
   #   run: mypy . --ignore-missing-imports --check-untyped-defs --explicit-package-bases --warn-unreachable
-  Job 2
+  Job 2:
   testcpu:
     name: CPU Tests
     runs-on: ubuntu-latest
@@ -621,6 +621,23 @@ class HFLM(LM):
 
         return loglikelihoods
 
+    def _batch_scheduler(self, pos, n_reordered_requests):
+        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
+        if sched in self.batch_sizes:
+            return self.batch_sizes[sched]
+        if (len(self.batch_sizes) > 1) and (
+            self.batch_sizes[sched - 1] == self.max_batch_size
+        ):
+            # if previous batch size is already maximal, skip recomputation
+            self.batch_sizes[sched] = self.max_batch_size
+            return self.batch_sizes[sched]
+        print(
+            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
+        )
+        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
+        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+        return self.batch_sizes[sched]
+
     def _loglikelihood_tokens(
         self, requests, disable_tqdm: bool = False, override_bs=None
     ):
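
The method above splits the reordered request list into self.batch_schedule roughly equal segments and re-detects the largest feasible batch size at most once per segment; since requests are sorted longest first, later segments hold shorter sequences and can often fit larger batches, and once a segment reaches max_batch_size all later segments reuse it without re-running detection. A worked example of the sched arithmetic with hypothetical numbers:

# hypothetical sizes: 1000 reordered requests, batch_schedule = 4
n_requests = 1000
batch_schedule = 4

def segment(pos):
    # same expression as `sched` in _batch_scheduler above
    return pos // int(n_requests / batch_schedule)

assert segment(0) == 0 and segment(249) == 0    # first quarter: one detection
assert segment(250) == 1 and segment(999) == 3  # each later quarter re-detects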
@@ -644,25 +661,6 @@ class HFLM(LM):
 
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        def _batch_scheduler(pos):
-            sched = pos // int(n_reordered_requests / self.batch_schedule)
-            if sched in self.batch_sizes:
-                return self.batch_sizes[sched]
-            if (len(self.batch_sizes) > 1) and (
-                self.batch_sizes[sched - 1] == self.max_batch_size
-            ):
-                # if previous batch size is already maximal, skip recomputation
-                self.batch_sizes[sched] = self.max_batch_size
-                return self.batch_sizes[sched]
-            print(
-                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
-            )
-            self.batch_sizes[sched] = self._detect_batch_size(
-                re_ord.get_reordered(), pos
-            )
-            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
-            return self.batch_sizes[sched]
-
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
             n=self.batch_size
@@ -670,7 +668,7 @@ class HFLM(LM):
             else override_bs
             if override_bs is not None
             else 0,
-            fn=_batch_scheduler
+            fn=self._batch_scheduler
             if self.batch_size == "auto"
             and n_reordered_requests > 0
             and not override_bs
@@ -838,12 +836,24 @@ class HFLM(LM):
             re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
 
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+
+        if self.batch_size == "auto":
+            # using rolling window with maximum context
+            print("Passed argument batch_size = auto. Detecting largest batch size")
+            batch_size = self._detect_batch_size()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
+
         # for each different set of kwargs, we execute all requests, by batch.
         for key, re_ord in re_ords.items():
             for chunk in utils.chunks(
-                re_ord.get_reordered(),
-                self.batch_size,
+                tqdm(re_ord.get_reordered(), disable=self.rank != 0),
+                n=self.batch_size
+                if self.batch_size != "auto"
+                else adaptive_batch_size
+                if adaptive_batch_size is not None
+                else 0,
+                fn=self._batch_scheduler
+                if self.batch_size == "auto" and not adaptive_batch_size
+                else None,
             ):
                 contexts, all_gen_kwargs = zip(*chunk)
                 # we assume all gen kwargs in the batch are the same
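
In greedy_until the two ternary chains decide how utils.chunks should batch: a fixed integer batch size is passed straight through; plain "auto" uses the adaptive_batch_size detected once above; only when no adaptive size is available does the per-position scheduler take over, with n=0 deferring entirely to fn. A hypothetical helper restating those two chains:

# hypothetical restatement of the n=/fn= arguments built in the diff above
def resolve_chunking(batch_size, adaptive_batch_size, scheduler):
    n = (
        batch_size
        if batch_size != "auto"
        else adaptive_batch_size
        if adaptive_batch_size is not None
        else 0
    )
    fn = (
        scheduler
        if batch_size == "auto" and not adaptive_batch_size
        else None
    )
    return n, fn

assert resolve_chunking(16, None, print) == (16, None)      # fixed size
assert resolve_chunking("auto", 32, print) == (32, None)    # one-off detection
assert resolve_chunking("auto", None, print) == (0, print)  # scheduled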
@@ -78,7 +78,7 @@ def chunks(iter, n: int = 0, fn=None):
     arr = []
     for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == (fn(i) if fn else n):
+        if len(arr) == (fn(i, iter) if fn else n):
             yield arr
             arr = []
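
With the extra argument, a scheduler callback can size chunks from the iterable itself instead of from closed-over locals. A small usage sketch with hypothetical data and callback (flushing the trailing partial chunk happens after the loop, outside the lines shown in this hunk):

# hypothetical callback: batches of 2 for the first half of the data,
# batches of 4 afterwards
def schedule(i, it):
    return 2 if i < len(it) // 2 else 4

data = list(range(10))
print(list(chunks(data, fn=schedule)))
# -> [[0, 1], [2, 3], [4, 5, 6, 7], [8, 9]]
# (the final [8, 9] assumes chunks yields any leftover items after the loop)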