Unverified Commit 8032b8d2 authored by Michael Baumgartner's avatar Michael Baumgartner Committed by GitHub
Browse files

make errors more general

parent 2aa85703
...@@ -191,10 +191,11 @@ class MemoryEstimatorDetection(MemoryEstimator): ...@@ -191,10 +191,11 @@ class MemoryEstimatorDetection(MemoryEstimator):
device = torch.device("cuda", self.gpu_id) device = torch.device("cuda", self.gpu_id)
logger.info(f"Estimating on {device} with shape {shape} and " logger.info(f"Estimating on {device} with shape {shape} and "
f"batch size {self.batch_size} and num_instances {num_instances}") f"batch size {self.batch_size} and num_instances {num_instances}")
try:
loss = None loss = None
opt = None opt = None
inp = None inp = None
block_tensor = None
try:
with cudnn_deterministic(): with cudnn_deterministic():
torch.cuda.reset_peak_memory_stats() torch.cuda.reset_peak_memory_stats()
network = network.to(device) network = network.to(device)
...@@ -237,14 +238,13 @@ class MemoryEstimatorDetection(MemoryEstimator): ...@@ -237,14 +238,13 @@ class MemoryEstimatorDetection(MemoryEstimator):
scaler.step(opt) scaler.step(opt)
scaler.update() scaler.update()
dyn_mem = torch.cuda.memory_reserved() dyn_mem = torch.cuda.memory_reserved()
except (RuntimeError,) as e: except Exception as e:
logger.info(f"Caught error (If out of memory error do not worry): {e}") logger.info(f"Caught error (If out of memory error do not worry): {e}")
empty_mem = 0 empty_mem = 0
fixed_mem = float('Inf') fixed_mem = float('Inf')
dyn_mem = float('Inf') dyn_mem = float('Inf')
finally: finally:
del loss del loss
del opt del opt
del inp del inp
del block_tensor del block_tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment