Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tsoc
superbenchmark
Commits
03b41be1
Unverified
Commit
03b41be1
authored
Jun 07, 2021
by
guoshzhao
Committed by
GitHub
Jun 07, 2021
Browse files
Benchmarks: Fix Bug - Fix OOM issue when run pytorch models sequentially. (#93)
* Clean up the cache.
parent
2d9be807
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
61 additions
and
24 deletions
+61
-24
superbench/benchmarks/base.py
superbench/benchmarks/base.py
+18
-17
superbench/benchmarks/model_benchmarks/pytorch_base.py
superbench/benchmarks/model_benchmarks/pytorch_base.py
+6
-0
superbench/benchmarks/return_code.py
superbench/benchmarks/return_code.py
+1
-0
tests/benchmarks/model_benchmarks/test_pytorch_base.py
tests/benchmarks/model_benchmarks/test_pytorch_base.py
+36
-7
No files found.
superbench/benchmarks/base.py
View file @
03b41be1
...
@@ -128,24 +128,25 @@ def run(self):
...
@@ -128,24 +128,25 @@ def run(self):
Return:
Return:
True if run benchmark successfully.
True if run benchmark successfully.
"""
"""
if
not
self
.
_preprocess
():
ret
=
True
return
False
try
:
ret
&=
self
.
_preprocess
()
if
ret
:
self
.
_start_time
=
datetime
.
utcnow
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
self
.
_start_time
=
datetime
.
utcnow
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
for
self
.
_curr_run_index
in
range
(
self
.
_args
.
run_count
):
for
self
.
_curr_run_index
in
range
(
self
.
_args
.
run_count
):
if
not
self
.
_benchmark
():
ret
&=
self
.
_benchmark
()
return
False
self
.
_end_time
=
datetime
.
utcnow
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
self
.
_end_time
=
datetime
.
utcnow
().
strftime
(
'%Y-%m-%d %H:%M:%S'
)
self
.
_result
.
set_timestamp
(
self
.
_start_time
,
self
.
_end_time
)
self
.
_result
.
set_timestamp
(
self
.
_start_time
,
self
.
_end_time
)
if
not
self
.
__check_result_format
():
if
ret
:
return
False
ret
&=
self
.
__check_result_format
()
except
BaseException
as
e
:
if
not
self
.
_postprocess
():
self
.
_result
.
set_return_code
(
ReturnCode
.
RUNTIME_EXCEPTION_ERROR
)
return
False
logger
.
error
(
'Run benchmark failed - benchmark: {}, message: {}'
.
format
(
self
.
_name
,
str
(
e
)))
finally
:
ret
&=
self
.
_postprocess
()
return
True
return
ret
def
__check_result_format
(
self
):
def
__check_result_format
(
self
):
"""Check the validation of result object.
"""Check the validation of result object.
...
...
superbench/benchmarks/model_benchmarks/pytorch_base.py
View file @
03b41be1
...
@@ -183,6 +183,12 @@ def _postprocess(self):
...
@@ -183,6 +183,12 @@ def _postprocess(self):
)
)
return
False
return
False
del
self
.
_model
del
self
.
_optimizer
del
self
.
_target
torch
.
cuda
.
empty_cache
()
return
True
return
True
def
_cal_params_count
(
self
):
def
_cal_params_count
(
self
):
...
...
superbench/benchmarks/return_code.py
View file @
03b41be1
...
@@ -13,6 +13,7 @@ class ReturnCode(Enum):
...
@@ -13,6 +13,7 @@ class ReturnCode(Enum):
INVALID_ARGUMENT
=
1
INVALID_ARGUMENT
=
1
INVALID_BENCHMARK_TYPE
=
2
INVALID_BENCHMARK_TYPE
=
2
INVALID_BENCHMARK_RESULT
=
3
INVALID_BENCHMARK_RESULT
=
3
RUNTIME_EXCEPTION_ERROR
=
4
# Return codes related with model benchmarks.
# Return codes related with model benchmarks.
NO_SUPPORTED_PRECISION
=
10
NO_SUPPORTED_PRECISION
=
10
DISTRIBUTED_SETTING_INIT_FAILURE
=
13
DISTRIBUTED_SETTING_INIT_FAILURE
=
13
...
...
tests/benchmarks/model_benchmarks/test_pytorch_base.py
View file @
03b41be1
...
@@ -173,17 +173,15 @@ def _inference_step(self, precision):
...
@@ -173,17 +173,15 @@ def _inference_step(self, precision):
@
decorator
.
pytorch_test
@
decorator
.
pytorch_test
def
test_pytorch_base
():
def
test_pytorch_base
():
"""Test PytorchBase class."""
"""Test PytorchBase class."""
# Register
BERT Base
benchmark.
# Register
mnist
benchmark.
BenchmarkRegistry
.
register_benchmark
(
'pytorch-mnist'
,
PytorchMNIST
)
BenchmarkRegistry
.
register_benchmark
(
'pytorch-mnist'
,
PytorchMNIST
)
# Launch benchmark with --no_gpu for testing.
# Launch benchmark with --no_gpu for testing.
context
=
BenchmarkRegistry
.
create_benchmark_context
(
parameters
=
'--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train inference --no_gpu'
'pytorch-mnist'
,
benchmark
=
PytorchMNIST
(
'pytorch-mnist'
,
parameters
=
parameters
)
parameters
=
'--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train inference --no_gpu'
)
benchmark
=
BenchmarkRegistry
.
launch_benchmark
(
context
)
assert
(
benchmark
)
assert
(
benchmark
)
assert
(
benchmark
.
_preprocess
())
assert
(
benchmark
.
_benchmark
())
assert
(
benchmark
.
name
==
'pytorch-mnist'
)
assert
(
benchmark
.
name
==
'pytorch-mnist'
)
assert
(
benchmark
.
return_code
==
ReturnCode
.
SUCCESS
)
assert
(
benchmark
.
return_code
==
ReturnCode
.
SUCCESS
)
...
@@ -231,3 +229,34 @@ def test_pytorch_base():
...
@@ -231,3 +229,34 @@ def test_pytorch_base():
assert
(
isinstance
(
benchmark
.
_optimizer
,
torch
.
optim
.
SGD
))
assert
(
isinstance
(
benchmark
.
_optimizer
,
torch
.
optim
.
SGD
))
benchmark
.
_optimizer_type
=
None
benchmark
.
_optimizer_type
=
None
assert
(
benchmark
.
_create_optimizer
()
is
False
)
assert
(
benchmark
.
_create_optimizer
()
is
False
)
# Test _postprocess().
assert
(
benchmark
.
_postprocess
())
@
decorator
.
cuda_test
@
decorator
.
pytorch_test
def
test_pytorch_empty_cache
():
"""Test PytorchBase class."""
# Register mnist benchmark.
BenchmarkRegistry
.
register_benchmark
(
'pytorch-mnist'
,
PytorchMNIST
)
# Test cache empty by manually calling torch.cuda.empty_cache().
parameters
=
'--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train'
benchmark
=
PytorchMNIST
(
'pytorch-mnist'
,
parameters
=
parameters
)
assert
(
benchmark
)
assert
(
benchmark
.
_preprocess
())
assert
(
benchmark
.
_benchmark
())
del
benchmark
assert
(
torch
.
cuda
.
memory_stats
()[
'reserved_bytes.all.current'
]
>
0
)
torch
.
cuda
.
empty_cache
()
assert
(
torch
.
cuda
.
memory_stats
()[
'reserved_bytes.all.current'
]
==
0
)
# Test automatic cache empty.
context
=
BenchmarkRegistry
.
create_benchmark_context
(
'pytorch-mnist'
,
parameters
=
'--batch_size 32 --num_warmup 8 --num_steps 64 --model_action train'
)
benchmark
=
BenchmarkRegistry
.
launch_benchmark
(
context
)
assert
(
benchmark
)
assert
(
torch
.
cuda
.
memory_stats
()[
'reserved_bytes.all.current'
]
==
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment