Commit 42c6b7df authored by Benjamin Fattori

additional external call to empty_cache + gc collect

parent 99304fe5
@@ -12,7 +12,7 @@ from tqdm import tqdm
 import torch
 import torch.nn.functional as F
 from accelerate import find_executable_batch_size
+import gc
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils
@@ -197,12 +197,14 @@ class BaseLM(LM):
             @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
             def forward_batch(batch_size):
                 test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
-                F.log_softmax(self._model_call(test_batch), dim = -1)
+                out = F.log_softmax(self._model_call(test_batch), dim = -1)
                 return batch_size
 
             batch_size = forward_batch()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
+            torch.cuda.empty_cache()
+            gc.collect()
 
         loglikelihoods = []
         for (string,) in tqdm(requests):
@@ -254,26 +256,27 @@ class BaseLM(LM):
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
         _, context_enc, continuation_enc = re_ord.get_reordered()[0]
         max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
         if (self.batch_size == 'auto'):
             if override_bs is None:
                 print('Passed argument batch_size = auto. Detecting largest batch size')
 
                 @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
                 def forward_batch(batch_size):
                     test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                    F.log_softmax(self._model_call(test_batch), dim = -1)
+                    out = F.log_softmax(self._model_call(test_batch), dim = -1)
                     return batch_size
 
                 batch_size = forward_batch()
                 print(f"Determined largest batch size: {batch_size}")
                 adaptive_batch_size = batch_size
+                torch.cuda.empty_cache()
+                gc.collect()
             else:
                 adaptive_batch_size = override_bs
 
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size if self.batch_size != "auto" else adaptive_batch_size
         ):
...
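
For reference, a minimal standalone sketch of the pattern this commit lands on: probe the largest workable batch size with accelerate's find_executable_batch_size, then explicitly release the probe's allocations with torch.cuda.empty_cache() and gc.collect() before the real evaluation loop begins. The names detect_batch_size, model_call, and max_length below are hypothetical stand-ins for the harness's inline detection code, self._model_call, and self.max_length; this is not the repository's exact implementation.

import gc

import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size


def detect_batch_size(model_call, max_length, device):
    """Return the largest batch size that survives a full forward pass."""

    # On CUDA OOM, the decorator halves batch_size and retries the probe.
    @find_executable_batch_size(starting_batch_size=512)
    def forward_batch(batch_size):
        test_batch = torch.ones((batch_size, max_length), device=device).long()
        # Run the forward pass plus log_softmax so activation memory is counted too.
        F.log_softmax(model_call(test_batch), dim=-1)
        return batch_size

    batch_size = forward_batch()

    # The commit follows the probe with explicit cleanup: empty_cache() releases
    # unused blocks held by PyTorch's caching allocator, and gc.collect() forces
    # a garbage-collection pass over any lingering Python references.
    torch.cuda.empty_cache()
    gc.collect()
    return batch_size

The two cleanup calls, in the same order the commit uses, keep the probe's test tensors from inflating peak memory during the subsequent batched forward passes.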