Unverified Commit 03b41be1 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Fix Bug - Fix OOM issue when run pytorch models sequentially. (#93)

* Clean up the cache.
parent 2d9be807
...@@ -128,24 +128,25 @@ def run(self): ...@@ -128,24 +128,25 @@ def run(self):
Return: Return:
True if run benchmark successfully. True if run benchmark successfully.
""" """
if not self._preprocess(): ret = True
return False try:
ret &= self._preprocess()
if ret:
self._start_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') self._start_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
for self._curr_run_index in range(self._args.run_count): for self._curr_run_index in range(self._args.run_count):
if not self._benchmark(): ret &= self._benchmark()
return False
self._end_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') self._end_time = datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
self._result.set_timestamp(self._start_time, self._end_time) self._result.set_timestamp(self._start_time, self._end_time)
if not self.__check_result_format(): if ret:
return False ret &= self.__check_result_format()
except BaseException as e:
if not self._postprocess(): self._result.set_return_code(ReturnCode.RUNTIME_EXCEPTION_ERROR)
return False logger.error('Run benchmark failed - benchmark: {}, message: {}'.format(self._name, str(e)))
finally:
ret &= self._postprocess()
return True return ret
def __check_result_format(self): def __check_result_format(self):
"""Check the validation of result object. """Check the validation of result object.
......
...@@ -183,6 +183,12 @@ def _postprocess(self): ...@@ -183,6 +183,12 @@ def _postprocess(self):
) )
return False return False
del self._model
del self._optimizer
del self._target
torch.cuda.empty_cache()
return True return True
def _cal_params_count(self): def _cal_params_count(self):
......
...@@ -13,6 +13,7 @@ class ReturnCode(Enum): ...@@ -13,6 +13,7 @@ class ReturnCode(Enum):
INVALID_ARGUMENT = 1 INVALID_ARGUMENT = 1
INVALID_BENCHMARK_TYPE = 2 INVALID_BENCHMARK_TYPE = 2
INVALID_BENCHMARK_RESULT = 3 INVALID_BENCHMARK_RESULT = 3
RUNTIME_EXCEPTION_ERROR = 4
# Return codes related with model benchmarks. # Return codes related with model benchmarks.
NO_SUPPORTED_PRECISION = 10 NO_SUPPORTED_PRECISION = 10
DISTRIBUTED_SETTING_INIT_FAILURE = 13 DISTRIBUTED_SETTING_INIT_FAILURE = 13
......
...@@ -173,17 +173,15 @@ def _inference_step(self, precision): ...@@ -173,17 +173,15 @@ def _inference_step(self, precision):
@decorator.pytorch_test @decorator.pytorch_test
def test_pytorch_base(): def test_pytorch_base():
"""Test PytorchBase class.""" """Test PytorchBase class."""
# Register BERT Base benchmark. # Register mnist benchmark.
BenchmarkRegistry.register_benchmark('pytorch-mnist', PytorchMNIST) BenchmarkRegistry.register_benchmark('pytorch-mnist', PytorchMNIST)
# Launch benchmark with --no_gpu for testing. # Launch benchmark with --no_gpu for testing.
context = BenchmarkRegistry.create_benchmark_context( parameters = '--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train inference --no_gpu'
'pytorch-mnist', benchmark = PytorchMNIST('pytorch-mnist', parameters=parameters)
parameters='--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train inference --no_gpu'
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark) assert (benchmark)
assert (benchmark._preprocess())
assert (benchmark._benchmark())
assert (benchmark.name == 'pytorch-mnist') assert (benchmark.name == 'pytorch-mnist')
assert (benchmark.return_code == ReturnCode.SUCCESS) assert (benchmark.return_code == ReturnCode.SUCCESS)
...@@ -231,3 +229,34 @@ def test_pytorch_base(): ...@@ -231,3 +229,34 @@ def test_pytorch_base():
assert (isinstance(benchmark._optimizer, torch.optim.SGD)) assert (isinstance(benchmark._optimizer, torch.optim.SGD))
benchmark._optimizer_type = None benchmark._optimizer_type = None
assert (benchmark._create_optimizer() is False) assert (benchmark._create_optimizer() is False)
# Test _postprocess().
assert (benchmark._postprocess())
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_empty_cache():
"""Test PytorchBase class."""
# Register mnist benchmark.
BenchmarkRegistry.register_benchmark('pytorch-mnist', PytorchMNIST)
# Test cache empty by manually calling torch.cuda.empty_cache().
parameters = '--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train'
benchmark = PytorchMNIST('pytorch-mnist', parameters=parameters)
assert (benchmark)
assert (benchmark._preprocess())
assert (benchmark._benchmark())
del benchmark
assert (torch.cuda.memory_stats()['reserved_bytes.all.current'] > 0)
torch.cuda.empty_cache()
assert (torch.cuda.memory_stats()['reserved_bytes.all.current'] == 0)
# Test automatic cache empty.
context = BenchmarkRegistry.create_benchmark_context(
'pytorch-mnist', parameters='--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train'
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
assert (benchmark)
assert (torch.cuda.memory_stats()['reserved_bytes.all.current'] == 0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment