Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
tsoc
superbenchmark
Commits
2a7ab691
Unverified
Commit
2a7ab691
authored
Apr 20, 2021
by
guoshzhao
Committed by
GitHub
Apr 20, 2021
Browse files
Benchmarks: Add Benchmark - Add LSTM model benchmarks. (#60)
* Benchmarks: Add Benchmark - Add LSTM model benchmarks.
parent
902ea211
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
321 additions
and
7 deletions
+321
-7
examples/benchmarks/pytorch_lstm.py
examples/benchmarks/pytorch_lstm.py
+41
-0
superbench/benchmarks/model_benchmarks/__init__.py
superbench/benchmarks/model_benchmarks/__init__.py
+2
-1
superbench/benchmarks/model_benchmarks/pytorch_bert.py
superbench/benchmarks/model_benchmarks/pytorch_bert.py
+3
-3
superbench/benchmarks/model_benchmarks/pytorch_gpt2.py
superbench/benchmarks/model_benchmarks/pytorch_gpt2.py
+3
-3
superbench/benchmarks/model_benchmarks/pytorch_lstm.py
superbench/benchmarks/model_benchmarks/pytorch_lstm.py
+196
-0
tests/benchmarks/model_benchmarks/test_pytorch_lstm.py
tests/benchmarks/model_benchmarks/test_pytorch_lstm.py
+76
-0
No files found.
examples/benchmarks/pytorch_lstm.py
0 → 100644
View file @
2a7ab691
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Model benchmark example for lstm (8-layer, 1024-hidden, 256-input_size, False-bidirectional).
Commands to run:
python3 examples/benchmarks/pytorch_lstm.py (Single GPU)
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_lstm.py
\
--distributed (Distributed)
"""
import
argparse
from
superbench.benchmarks
import
Platform
,
Framework
,
BenchmarkRegistry
from
superbench.common.utils
import
logger
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--distributed'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Whether to enable distributed training.'
)
args
=
parser
.
parse_args
()
# Specify the model name and benchmark parameters.
model_name
=
'lstm'
parameters
=
'--batch_size 1 --seq_len 256 --precision float32 --num_warmup 8 --num_steps 64 --run_count 2'
if
args
.
distributed
:
parameters
+=
' --distributed_impl ddp --distributed_backend nccl'
# Create context for lstm benchmark and run it for 64 steps.
context
=
BenchmarkRegistry
.
create_benchmark_context
(
model_name
,
platform
=
Platform
.
CUDA
,
parameters
=
parameters
,
framework
=
Framework
.
PYTORCH
)
benchmark
=
BenchmarkRegistry
.
launch_benchmark
(
context
)
if
benchmark
:
logger
.
info
(
'benchmark: {}, return code: {}, result: {}'
.
format
(
benchmark
.
name
,
benchmark
.
return_code
,
benchmark
.
result
)
)
superbench/benchmarks/model_benchmarks/__init__.py
View file @
2a7ab691
...
@@ -7,5 +7,6 @@ from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
...
@@ -7,5 +7,6 @@ from superbench.benchmarks.model_benchmarks.model_base import ModelBenchmark
from
superbench.benchmarks.model_benchmarks.pytorch_bert
import
PytorchBERT
from
superbench.benchmarks.model_benchmarks.pytorch_bert
import
PytorchBERT
from
superbench.benchmarks.model_benchmarks.pytorch_gpt2
import
PytorchGPT2
from
superbench.benchmarks.model_benchmarks.pytorch_gpt2
import
PytorchGPT2
from
superbench.benchmarks.model_benchmarks.pytorch_cnn
import
PytorchCNN
from
superbench.benchmarks.model_benchmarks.pytorch_cnn
import
PytorchCNN
from
superbench.benchmarks.model_benchmarks.pytorch_lstm
import
PytorchLSTM
__all__
=
[
'ModelBenchmark'
,
'PytorchBERT'
,
'PytorchGPT2'
,
'PytorchCNN'
]
__all__
=
[
'ModelBenchmark'
,
'PytorchBERT'
,
'PytorchGPT2'
,
'PytorchCNN'
,
'PytorchLSTM'
]
superbench/benchmarks/model_benchmarks/pytorch_bert.py
View file @
2a7ab691
...
@@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
...
@@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
class
BertBenchmarkModel
(
torch
.
nn
.
Module
):
class
BertBenchmarkModel
(
torch
.
nn
.
Module
):
"""The BERT model for benchmarking."""
"""The BERT model for benchmarking."""
def
__init__
(
self
,
config
,
num_class
):
def
__init__
(
self
,
config
,
num_class
es
):
"""Constructor.
"""Constructor.
Args:
Args:
config (BertConfig): Configurations of BERT model.
config (BertConfig): Configurations of BERT model.
num_class (int): The number of objects for classification.
num_class
es
(int): The number of objects for classification.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_bert
=
BertModel
(
config
)
self
.
_bert
=
BertModel
(
config
)
self
.
_linear
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
num_class
)
self
.
_linear
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
num_class
es
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
"""Forward propagation function.
"""Forward propagation function.
...
...
superbench/benchmarks/model_benchmarks/pytorch_gpt2.py
View file @
2a7ab691
...
@@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
...
@@ -17,16 +17,16 @@ from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDat
class
GPT2BenchmarkModel
(
torch
.
nn
.
Module
):
class
GPT2BenchmarkModel
(
torch
.
nn
.
Module
):
"""The GPT2 model for benchmarking."""
"""The GPT2 model for benchmarking."""
def
__init__
(
self
,
config
,
num_class
):
def
__init__
(
self
,
config
,
num_class
es
):
"""Constructor.
"""Constructor.
Args:
Args:
config (GPT2Config): Configurations of GPT2 model.
config (GPT2Config): Configurations of GPT2 model.
num_class (int): The number of objects for classification.
num_class
es
(int): The number of objects for classification.
"""
"""
super
().
__init__
()
super
().
__init__
()
self
.
_bert
=
GPT2Model
(
config
)
self
.
_bert
=
GPT2Model
(
config
)
self
.
_linear
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
num_class
)
self
.
_linear
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
num_class
es
)
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
"""Forward propagation function.
"""Forward propagation function.
...
...
superbench/benchmarks/model_benchmarks/pytorch_lstm.py
0 → 100644
View file @
2a7ab691
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the Pytorch LSTM model."""
import
time
import
torch
from
superbench.common.utils
import
logger
from
superbench.benchmarks
import
BenchmarkRegistry
,
Precision
from
superbench.benchmarks.model_benchmarks.model_base
import
Optimizer
from
superbench.benchmarks.model_benchmarks.pytorch_base
import
PytorchBase
from
superbench.benchmarks.model_benchmarks.random_dataset
import
TorchRandomDataset
class
LSTMBenchmarkModel
(
torch
.
nn
.
Module
):
"""The LSTM model for benchmarking."""
def
__init__
(
self
,
input_size
,
hidden_size
,
num_layers
,
bidirectional
,
num_classes
):
"""Constructor.
Args:
input_size (int): The number of expected features in the input.
hidden_size (int): The number of features in the hidden state.
num_layers (int): The number of recurrent layers.
bidirectional (bool): If True, becomes a bidirectional LSTM.
num_classes (int): The number of objects for classification.
"""
super
().
__init__
()
self
.
_lstm
=
torch
.
nn
.
LSTM
(
input_size
,
hidden_size
,
num_layers
,
batch_first
=
True
,
bidirectional
=
bidirectional
)
self
.
_linear
=
torch
.
nn
.
Linear
(
hidden_size
,
num_classes
)
def
forward
(
self
,
input
):
"""Forward propagation function.
Args:
input (torch.FloatTensor): Tensor containing the features of the input sequence,
shape (sequence_length, batch_size, input_size).
Return:
result (torch.FloatTensor): The output features from the last layer of the LSTM
further processed by a Linear layer, shape (batch_size, num_classes).
"""
self
.
_lstm
.
flatten_parameters
()
outputs
=
self
.
_lstm
(
input
)
result
=
self
.
_linear
(
outputs
[
0
][:,
-
1
,
:])
return
result
class
PytorchLSTM
(
PytorchBase
):
"""The LSTM benchmark class."""
def
__init__
(
self
,
name
,
parameters
=
''
):
"""Constructor.
Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super
().
__init__
(
name
,
parameters
)
self
.
_config
=
None
self
.
_supported_precision
=
[
Precision
.
FLOAT32
,
Precision
.
FLOAT16
]
self
.
_optimizer_type
=
Optimizer
.
SGD
self
.
_loss_fn
=
torch
.
nn
.
CrossEntropyLoss
()
def
add_parser_arguments
(
self
):
"""Add the LSTM-specified arguments.
LSTM model reference: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
"""
super
().
add_parser_arguments
()
self
.
_parser
.
add_argument
(
'--num_classes'
,
type
=
int
,
default
=
100
,
required
=
False
,
help
=
'The number of objects for classification.'
)
self
.
_parser
.
add_argument
(
'--input_size'
,
type
=
int
,
default
=
256
,
required
=
False
,
help
=
'The number of expected features in the input.'
)
self
.
_parser
.
add_argument
(
'--hidden_size'
,
type
=
int
,
default
=
1024
,
required
=
False
,
help
=
'The number of features in the hidden state.'
)
self
.
_parser
.
add_argument
(
'--num_layers'
,
type
=
int
,
default
=
8
,
required
=
False
,
help
=
'The number of recurrent layers.'
)
self
.
_parser
.
add_argument
(
'--bidirectional'
,
action
=
'store_true'
,
default
=
False
,
help
=
'Bidirectional LSTM.'
)
self
.
_parser
.
add_argument
(
'--seq_len'
,
type
=
int
,
default
=
512
,
required
=
False
,
help
=
'Sequence length.'
)
def
_generate_dataset
(
self
):
"""Generate dataset for benchmarking according to shape info.
Return:
True if dataset is created successfully.
"""
self
.
_dataset
=
TorchRandomDataset
(
[
self
.
_args
.
sample_count
,
self
.
_args
.
seq_len
,
self
.
_args
.
input_size
],
self
.
_world_size
,
dtype
=
torch
.
float
)
if
len
(
self
.
_dataset
)
==
0
:
logger
.
error
(
'Generate random dataset failed - model: {}'
.
format
(
self
.
_name
))
return
False
return
True
def
_create_model
(
self
,
precision
):
"""Construct the model for benchmarking.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
"""
try
:
self
.
_model
=
LSTMBenchmarkModel
(
self
.
_args
.
input_size
,
self
.
_args
.
hidden_size
,
self
.
_args
.
num_layers
,
self
.
_args
.
bidirectional
,
self
.
_args
.
num_classes
)
self
.
_model
=
self
.
_model
.
to
(
dtype
=
getattr
(
torch
,
precision
.
value
))
if
self
.
_gpu_available
:
self
.
_model
=
self
.
_model
.
cuda
()
except
BaseException
as
e
:
logger
.
error
(
'Create model with specified precision failed - model: {}, precision: {}, message: {}.'
.
format
(
self
.
_name
,
precision
,
str
(
e
)
)
)
return
False
self
.
_target
=
torch
.
LongTensor
(
self
.
_args
.
batch_size
).
random_
(
self
.
_args
.
num_classes
)
if
self
.
_gpu_available
:
self
.
_target
=
self
.
_target
.
cuda
()
return
True
def
_train_step
(
self
,
precision
):
"""Define the training process.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
Return:
The step-time list of every training step.
"""
duration
=
[]
curr_step
=
0
while
True
:
for
idx
,
sample
in
enumerate
(
self
.
_dataloader
):
start
=
time
.
time
()
sample
=
sample
.
to
(
dtype
=
getattr
(
torch
,
precision
.
value
))
if
self
.
_gpu_available
:
sample
=
sample
.
cuda
()
self
.
_optimizer
.
zero_grad
()
output
=
self
.
_model
(
sample
)
loss
=
self
.
_loss_fn
(
output
,
self
.
_target
)
loss
.
backward
()
self
.
_optimizer
.
step
()
end
=
time
.
time
()
curr_step
+=
1
if
curr_step
>
self
.
_args
.
num_warmup
:
# Save the step time of every training/inference step, unit is millisecond.
duration
.
append
((
end
-
start
)
*
1000
)
if
self
.
_is_finished
(
curr_step
,
end
):
return
duration
def
_inference_step
(
self
,
precision
):
"""Define the inference process.
Args:
precision (Precision): precision of model and input data,
such as float32, float16.
Return:
The latency list of every inference operation.
"""
duration
=
[]
curr_step
=
0
with
torch
.
no_grad
():
self
.
_model
.
eval
()
while
True
:
for
idx
,
sample
in
enumerate
(
self
.
_dataloader
):
start
=
time
.
time
()
sample
=
sample
.
to
(
dtype
=
getattr
(
torch
,
precision
.
value
))
if
self
.
_gpu_available
:
sample
=
sample
.
cuda
()
self
.
_model
(
sample
)
if
self
.
_gpu_available
:
torch
.
cuda
.
synchronize
()
end
=
time
.
time
()
curr_step
+=
1
if
curr_step
>
self
.
_args
.
num_warmup
:
# Save the step time of every training/inference step, unit is millisecond.
duration
.
append
((
end
-
start
)
*
1000
)
if
self
.
_is_finished
(
curr_step
,
end
):
return
duration
# Register LSTM benchmark.
BenchmarkRegistry
.
register_benchmark
(
'pytorch-lstm'
,
PytorchLSTM
,
parameters
=
'--input_size=256 --hidden_size=1024 --num_layers=8'
)
tests/benchmarks/model_benchmarks/test_pytorch_lstm.py
0 → 100644
View file @
2a7ab691
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for LSTM model benchmarks."""
from
tests.helper
import
decorator
from
superbench.benchmarks
import
BenchmarkRegistry
,
Platform
,
Framework
,
BenchmarkType
,
ReturnCode
from
superbench.benchmarks.model_benchmarks.pytorch_lstm
import
PytorchLSTM
@
decorator
.
cuda_test
@
decorator
.
pytorch_test
def
test_pytorch_lstm_with_gpu
():
"""Test pytorch-lstm benchmark with GPU."""
run_pytorch_lstm
(
parameters
=
'--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4
\
--model_action train inference'
,
check_metrics
=
[
'steptime_train_float32'
,
'throughput_train_float32'
,
'steptime_train_float16'
,
'throughput_train_float16'
,
'steptime_inference_float32'
,
'throughput_inference_float32'
,
'steptime_inference_float16'
,
'throughput_inference_float16'
]
)
@
decorator
.
pytorch_test
def
test_pytorch_lstm_no_gpu
():
"""Test pytorch-lstm benchmark with CPU."""
run_pytorch_lstm
(
parameters
=
'--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4
\
--model_action train inference --precision float32 --no_gpu'
,
check_metrics
=
[
'steptime_train_float32'
,
'throughput_train_float32'
,
'steptime_inference_float32'
,
'throughput_inference_float32'
]
)
def
run_pytorch_lstm
(
parameters
=
''
,
check_metrics
=
[]):
"""Test pytorch-lstm benchmark."""
context
=
BenchmarkRegistry
.
create_benchmark_context
(
'lstm'
,
platform
=
Platform
.
CUDA
,
parameters
=
parameters
,
framework
=
Framework
.
PYTORCH
)
assert
(
BenchmarkRegistry
.
is_benchmark_context_valid
(
context
))
benchmark
=
BenchmarkRegistry
.
launch_benchmark
(
context
)
# Check basic information.
assert
(
benchmark
)
assert
(
isinstance
(
benchmark
,
PytorchLSTM
))
assert
(
benchmark
.
name
==
'pytorch-lstm'
)
assert
(
benchmark
.
type
==
BenchmarkType
.
MODEL
)
# Check predefined parameters of lstm model.
assert
(
benchmark
.
_args
.
input_size
==
256
)
assert
(
benchmark
.
_args
.
hidden_size
==
1024
)
assert
(
benchmark
.
_args
.
num_layers
==
8
)
# Check parameters specified in BenchmarkContext.
assert
(
benchmark
.
_args
.
batch_size
==
1
)
assert
(
benchmark
.
_args
.
num_classes
==
5
)
assert
(
benchmark
.
_args
.
seq_len
==
8
)
assert
(
benchmark
.
_args
.
num_warmup
==
2
)
assert
(
benchmark
.
_args
.
num_steps
==
4
)
# Check dataset scale.
assert
(
len
(
benchmark
.
_dataset
)
==
benchmark
.
_args
.
sample_count
*
benchmark
.
_world_size
)
# Check results and metrics.
assert
(
benchmark
.
run_count
==
1
)
assert
(
benchmark
.
return_code
==
ReturnCode
.
SUCCESS
)
for
metric
in
check_metrics
:
assert
(
len
(
benchmark
.
raw_data
[
metric
])
==
benchmark
.
run_count
)
assert
(
len
(
benchmark
.
raw_data
[
metric
][
0
])
==
benchmark
.
_args
.
num_steps
)
assert
(
len
(
benchmark
.
result
[
metric
])
==
benchmark
.
run_count
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment