Commit e87c0838 authored by Benjamin Fattori

compute multiple forward passes in autobatcher to improve robustness to OOMs

parent d86de51d
@@ -197,14 +197,13 @@ class BaseLM(LM):
             @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
             def forward_batch(batch_size):
                 test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
-                out = F.log_softmax(self._model_call(test_batch), dim = -1)
+                for _ in range(5):
+                    out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
                 return batch_size

             batch_size = forward_batch()
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
-            torch.cuda.empty_cache()
-            gc.collect()

         loglikelihoods = []
         for (string,) in tqdm(requests):
@@ -265,15 +264,14 @@ class BaseLM(LM):
             @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
             def forward_batch(batch_size):
                 test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                out = F.log_softmax(self._model_call(test_batch), dim = -1)
+                for _ in range(5):
+                    out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
                 return batch_size

             batch_size = forward_batch()
             print(f"Determined largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
-            torch.cuda.empty_cache()
-            gc.collect()
         else:
             adaptive_batch_size = override_bs
...
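For context, forward_batch is wrapped in a find_executable_batch_size decorator that retries the probe at halved batch sizes whenever it hits a CUDA OOM (the harness's helper follows the pattern of HuggingFace accelerate's utility of the same name). Below is a minimal, self-contained sketch of that pattern; the decorator body, the toy embedding standing in for self._model_call, and the device/max_length values are illustrative assumptions, not the repo's code.

import gc

import torch
import torch.nn.functional as F


def find_executable_batch_size(starting_batch_size=512):
    """Sketch of an accelerate-style auto-batcher: call fn(batch_size),
    halving the batch size on CUDA OOM until some size succeeds."""
    def decorator(fn):
        def wrapper():
            batch_size = starting_batch_size
            while batch_size >= 1:
                try:
                    return fn(batch_size)
                except RuntimeError as e:
                    if "out of memory" not in str(e).lower():
                        raise  # unrelated error: propagate
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    batch_size //= 2
            raise RuntimeError("no executable batch size found")
        return wrapper
    return decorator


# Hypothetical stand-ins for the harness's model and settings (illustration only).
device = "cuda" if torch.cuda.is_available() else "cpu"
toy_model = torch.nn.Embedding(50257, 64).to(device)  # plays the role of self._model_call
max_length = 128


@find_executable_batch_size(starting_batch_size=512)
def forward_batch(batch_size):
    test_batch = torch.ones((batch_size, max_length), device=device).long()
    # Probe with several passes, not one: repeated forwards plus the
    # materialized log-softmax output approximate peak memory during real
    # evaluation. .cpu() keeps the result off-GPU, as the committed code does.
    for _ in range(5):
        out = F.log_softmax(toy_model(test_batch), dim=-1).cpu()
    return batch_size


print(f"Determined largest batch size: {forward_batch()}")

The point of the new for _ in range(5) loop is that a single dummy pass can succeed at a batch size whose steady-state memory use (activation buffers, allocator fragmentation, the materialized log-softmax output) still OOMs partway through real evaluation; repeating the probe makes the measured peak more representative of the actual run.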