"test/vscode:/vscode.git/clone" did not exist on "a5a892ffd3d38d30a8ec2e7e725efb8ec2daafd0"
Commit 1aa3bc1e authored by Zhiwei Zhuang

add _batch_scheduler in greedy_until

parent 2afe7770
@@ -620,6 +620,25 @@ class HFLM(LM):
             loglikelihoods.append(string_nll)
         return loglikelihoods
 
+    def _batch_scheduler(self, pos, n_reordered_requests):
+        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
+        if sched in self.batch_sizes:
+            return self.batch_sizes[sched]
+        if (len(self.batch_sizes) > 1) and (
+            self.batch_sizes[sched - 1] == self.max_batch_size
+        ):
+            # if previous batch size is already maximal, skip recomputation
+            self.batch_sizes[sched] = self.max_batch_size
+            return self.batch_sizes[sched]
+        print(
+            f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
+        )
+        self.batch_sizes[sched] = self._detect_batch_size(
+            n_reordered_requests, pos
+        )
+        print(f"Determined largest batch size: {self.batch_sizes[sched]}")
+        return self.batch_sizes[sched]
+
     def _loglikelihood_tokens(
         self, requests, disable_tqdm: bool = False, override_bs=None
@@ -644,25 +663,6 @@ class HFLM(LM):
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
-        def _batch_scheduler(pos):
-            sched = pos // int(n_reordered_requests / self.batch_schedule)
-            if sched in self.batch_sizes:
-                return self.batch_sizes[sched]
-            if (len(self.batch_sizes) > 1) and (
-                self.batch_sizes[sched - 1] == self.max_batch_size
-            ):
-                # if previous batch size is already maximal, skip recomputation
-                self.batch_sizes[sched] = self.max_batch_size
-                return self.batch_sizes[sched]
-            print(
-                f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
-            )
-            self.batch_sizes[sched] = self._detect_batch_size(
-                re_ord.get_reordered(), pos
-            )
-            print(f"Determined largest batch size: {self.batch_sizes[sched]}")
-            return self.batch_sizes[sched]
-
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
             n=self.batch_size
@@ -670,7 +670,7 @@ class HFLM(LM):
             else override_bs
             if override_bs is not None
             else 0,
-            fn=_batch_scheduler
+            fn=self._batch_scheduler
             if self.batch_size == "auto"
             and n_reordered_requests > 0
             and not override_bs
@@ -838,13 +838,27 @@ class HFLM(LM):
             re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
 
         pbar = tqdm(total=len(requests), disable=(self.rank != 0))
+
+        if self.batch_size == "auto":
+            # using rolling window with maximum context
+            print("Passed argument batch_size = auto. Detecting largest batch size")
+            batch_size = self._detect_batch_size()
+            print(f"Determined largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
+
         # for each different set of kwargs, we execute all requests, by batch.
         for key, re_ord in re_ords.items():
             for chunk in utils.chunks(
                 re_ord.get_reordered(),
-                self.batch_size,
+                n=self.batch_size
+                if self.batch_size != "auto"
+                else adaptive_batch_size
+                if adaptive_batch_size is not None
+                else 0,
+                fn=self._batch_scheduler
+                if self.batch_size == "auto"
+                and not adaptive_batch_size
+                else None,
             ):
                 contexts, all_gen_kwargs = zip(*chunk)
                 # we assume all gen kwargs in the batch are the same
                 # this is safe to assume because the `grouper` object ensures it.
...
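Taken together, these hunks promote `_batch_scheduler` from a closure inside `_loglikelihood_tokens` to a method on `HFLM`, so `greedy_until` can reuse the same adaptive batch sizing. Note one subtlety visible in the diff: the old closure's `n_reordered_requests` was an integer count, while the new method's parameter of the same name is the reordered request list itself (hence the `len(...)` call). Below is a minimal standalone sketch of that scheduling logic; `SimpleScheduler`, the stubbed `_detect_batch_size`, and the constants are illustrative assumptions, not code from this commit.

# Standalone sketch (assumed names): carves the reordered request list into
# `batch_schedule` slots and caches one detected batch size per slot.
class SimpleScheduler:
    def __init__(self, batch_schedule=4, max_batch_size=64):
        self.batch_schedule = batch_schedule  # number of re-detection points
        self.max_batch_size = max_batch_size  # hard ceiling on batch size
        self.batch_sizes = {}  # cache: schedule slot -> detected batch size

    def _detect_batch_size(self, requests, pos):
        # Stand-in for the real OOM-probing search: pretend later positions
        # (shorter requests, since they are sorted longest-first) fit more.
        return min(self.max_batch_size, 8 * (1 + pos // 100))

    def _batch_scheduler(self, pos, n_reordered_requests):
        # Map the current position to one of `batch_schedule` slots.
        sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
        if sched in self.batch_sizes:
            return self.batch_sizes[sched]
        if len(self.batch_sizes) > 1 and self.batch_sizes[sched - 1] == self.max_batch_size:
            # Previous slot already hit the ceiling, so skip re-detection.
            self.batch_sizes[sched] = self.max_batch_size
            return self.batch_sizes[sched]
        self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
        return self.batch_sizes[sched]


requests = list(range(400))  # 400 dummy requests
s = SimpleScheduler()
print([s._batch_scheduler(p, requests) for p in (0, 150, 250, 399)])  # [8, 16, 24, 32]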
@@ -78,7 +78,7 @@ def chunks(iter, n: int = 0, fn=None):
     arr = []
     for i, x in enumerate(iter):
         arr.append(x)
-        if len(arr) == (fn(i) if fn else n):
+        if len(arr) == (fn(i, iter) if fn else n):
             yield arr
             arr = []
...
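And a hedged usage sketch of the patched `chunks`: the callback now receives both the running index and the full iterable, so a method like `_batch_scheduler` can size batches from `len(iter)`. The size schedule in the lambda and the trailing partial-batch yield (added so the sketch is self-contained) are assumptions for illustration.

# Sketch of the two-argument callback contract in the patched helper.
def chunks(iter, n: int = 0, fn=None):
    arr = []
    for i, x in enumerate(iter):
        arr.append(x)
        # fn now sees the whole iterable, not just the index
        if len(arr) == (fn(i, iter) if fn else n):
            yield arr
            arr = []
    if arr:
        # trailing partial batch (yielded here so the sketch is complete)
        yield arr


data = list(range(10))
# Batch size 2 for the first half of the data, 4 for the second half.
batches = list(chunks(data, fn=lambda i, it: 2 if i < len(it) // 2 else 4))
print([len(b) for b in batches])  # [2, 2, 4, 2]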