Commit f1dedbf0 authored by Benjamin Fattori

loglikelihood_rolling computes the adaptive batch size separately and passes the computed batch size to _loglikelihood_tokens to avoid recomputation
parent b824fc91
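
For context, a minimal sketch of the probe pattern this commit reuses, assuming the find_executable_batch_size decorator from Hugging Face accelerate; the helper name detect_batch_size and its arguments are illustrative, not part of the commit:

    # Minimal sketch, not harness code: probe the largest batch size that fits
    # in memory by letting accelerate halve `batch_size` on each OOM error.
    import torch
    from accelerate.utils import find_executable_batch_size


    def detect_batch_size(model_call, seq_len, device, starting_batch_size=512):
        # `find_executable_batch_size` re-invokes the decorated function with a
        # halved `batch_size` whenever it raises an out-of-memory error.
        @find_executable_batch_size(starting_batch_size=starting_batch_size)
        def forward_batch(batch_size):
            test_batch = torch.ones((batch_size, seq_len), device=device).long()
            model_call(test_batch)  # only the forward pass matters; output is discarded
            return batch_size

        return forward_batch()

Running this probe once in loglikelihood_rolling and threading the result through override_bs means the trial forward passes are not repeated inside _loglikelihood_tokens for every rolling-window call.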
@@ -188,7 +188,21 @@ class BaseLM(LM):
     def loglikelihood_rolling(self, requests):
         # TODO: Implement caching once we've confirmed the perplexity implementation
-        # TODO: automatic batch size detection for vectorization
+        # automatic batch size detection for vectorization
+        adaptive_batch_size = None
+        if self.batch_size == 'auto':
+            # using rolling window with maximum context
+            print('Passed argument batch_size = auto. Detecting largest batch size')
+            @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
+            def forward_batch(batch_size):
+                test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
+                self._model_call(test_batch)
+                return batch_size
+            batch_size = forward_batch()
+            print(f"Determined Largest batch size: {batch_size}")
+            adaptive_batch_size = batch_size
 
         loglikelihoods = []
         for (string,) in tqdm(requests):
@@ -209,7 +223,7 @@ class BaseLM(LM):
             # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
             # that
             string_nll = self._loglikelihood_tokens(
-                rolling_token_windows, disable_tqdm=True
+                rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size
             )
 
             # discard is_greedy
@@ -220,7 +234,7 @@ class BaseLM(LM):
         return loglikelihoods
 
-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False, override_bs=None):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
@@ -241,18 +255,24 @@ class BaseLM(LM):
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
         _, context_enc, continuation_enc = re_ord.get_reordered()[0]
-        max_context = len(context_enc) + len(continuation_enc)
-        if self.batch_size == 'auto':
-            print('Passed argument batch_size = auto. Detecting largest batch size')
-            @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
-            def forward_batch(batch_size):
-                test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                self._model_call(test_batch)
-                return batch_size
-            batch_size = forward_batch()
-            print(f"Determined Largest batch size: {batch_size}")
-            adaptive_batch_size = batch_size
+        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
+
+        if self.batch_size == 'auto':
+            if override_bs is None:
+                print('Passed argument batch_size = auto. Detecting largest batch size')
+                @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
+                def forward_batch(batch_size):
+                    test_batch = torch.ones((batch_size, max_context), device=self.device).long()
+                    self._model_call(test_batch)
+                    return batch_size
+                batch_size = forward_batch()
+                print(f"Determined Largest batch size: {batch_size}")
+                adaptive_batch_size = batch_size
+            else:
+                adaptive_batch_size = override_bs
 
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size if self.batch_size != "auto" else adaptive_batch_size
...