Commit 99304fe5 authored by Benjamin Fattori

wrap self._model_call with F.log_softmax + remove empty_cache, not needed

parent 8a89b30c
@@ -197,7 +197,7 @@ class BaseLM(LM):
         @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
         def forward_batch(batch_size):
             test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
-            self._model_call(test_batch)
+            F.log_softmax(self._model_call(test_batch), dim=-1)
             return batch_size
 
         batch_size = forward_batch()
@@ -264,7 +264,7 @@ class BaseLM(LM):
             @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
             def forward_batch(batch_size):
                 test_batch = torch.ones((batch_size, max_context), device=self.device).long()
-                self._model_call(test_batch)
+                F.log_softmax(self._model_call(test_batch), dim=-1)
                 return batch_size
 
             batch_size = forward_batch()
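
For context: `find_executable_batch_size` (from HuggingFace `accelerate`) reruns its wrapped function with a halved batch size whenever it hits a CUDA OOM, so the probe only finds a batch size that survives evaluation if it reproduces the real peak memory. The actual scoring path applies `F.log_softmax` over the full vocabulary, which materializes a second `(batch, seq, vocab)` tensor on top of the logits; probing `self._model_call` alone under-estimates that peak, which is what this commit fixes. Below is a minimal sketch of the pattern with a hypothetical two-layer stand-in for `self._model_call`; the sizes assume a CUDA device, so shrink them for a CPU run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import find_executable_batch_size

device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_size, hidden, max_length = 50257, 512, 2048  # illustrative sizes

# Hypothetical stand-in for self._model_call: token ids -> logits.
embed = nn.Embedding(vocab_size, hidden).to(device)
lm_head = nn.Linear(hidden, vocab_size).to(device)

def model_call(inps):
    return lm_head(embed(inps))  # (batch, seq, vocab) logits

@find_executable_batch_size(starting_batch_size=512)  # halves batch_size on OOM and retries
def forward_batch(batch_size):
    test_batch = torch.ones((batch_size, max_length), device=device).long()
    # Probe through log_softmax too: it allocates another
    # (batch, seq, vocab) tensor, which dominates peak memory
    # for large vocabularies.
    F.log_softmax(model_call(test_batch), dim=-1)
    return batch_size

batch_size = forward_batch()
print(f"largest batch size that fits: {batch_size}")
```
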
@@ -274,8 +274,6 @@ class BaseLM(LM):
         else:
             adaptive_batch_size = override_bs
 
-        torch.cuda.empty_cache()  # empty cache after determining batch size
-
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size if self.batch_size != "auto" else adaptive_batch_size
         ):
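
On the removed `torch.cuda.empty_cache()`: memory freed by the probe stays in PyTorch's caching allocator and is reused by later allocations in the same process, so the evaluation loop that follows gets that memory back without clearing the cache. `empty_cache()` mainly hands cached blocks back to the driver for the benefit of *other* processes, at the cost of fresh allocations later. A small snippet for watching the two pools, if you want to verify this on your own setup:

```python
import torch

if torch.cuda.is_available():
    x = torch.ones((4096, 4096), device="cuda")  # ~64 MB of float32
    del x
    # The tensor is gone, but its block stays cached for reuse:
    print(torch.cuda.memory_allocated())  # bytes held by live tensors: ~0
    print(torch.cuda.memory_reserved())   # bytes held by the caching allocator: ~64 MB
    torch.cuda.empty_cache()              # hand cached blocks back to the driver
    print(torch.cuda.memory_reserved())   # typically drops back toward 0
```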