Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tsoc
superbenchmark
Commits
af567cf6
Unverified
Commit
af567cf6
authored
Apr 16, 2021
by
guoshzhao
Committed by
GitHub
Apr 16, 2021
Browse files
Benchmarks: Add Benchmark - Add GPT2 model benchmark. (#57)
* Benchmarks: Add Benchmark - Add GPT2 model benchmark.
parent
fb850af7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
306 additions
and
1 deletion
+306
-1
examples/benchmarks/pytorch_gpt2_large.py
examples/benchmarks/pytorch_gpt2_large.py
+41
-0
superbench/benchmarks/model_benchmarks/__init__.py
superbench/benchmarks/model_benchmarks/__init__.py
+2
-1
superbench/benchmarks/model_benchmarks/pytorch_gpt2.py
superbench/benchmarks/model_benchmarks/pytorch_gpt2.py
+205
-0
tests/benchmarks/model_benchmarks/test_pytorch_gpt2.py
tests/benchmarks/model_benchmarks/test_pytorch_gpt2.py
+58
-0
No files found.
examples/benchmarks/pytorch_gpt2_large.py
0 → 100644
View file @
af567cf6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Model benchmark example for gpt2-large (36-layer, 1280-hidden, 20-heads, 774M parameters).
Commands to run:
python3 examples/benchmarks/pytorch_gpt2_large.py (Single GPU)
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_gpt2_large.py
\
--distributed (Distributed)
"""
import
argparse
from
superbench.benchmarks
import
Platform
,
Framework
,
BenchmarkRegistry
from
superbench.common.utils
import
logger
def main():
    """Run the gpt2-large model benchmark once and log its outcome.

    The only command-line option is ``--distributed``; when it is given, the
    DDP/NCCL settings are appended to the benchmark parameters.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
    )
    options = cli.parse_args()

    # Specify the model name and benchmark parameters.
    model_name = 'gpt2-large'
    distributed_suffix = ' --distributed_impl ddp --distributed_backend nccl' if options.distributed else ''
    parameters = (
        '--batch_size 1 --duration 120 --seq_len 128 --precision float32 --run_count 2' + distributed_suffix
    )

    # Create context for gpt2-large benchmark and run it for 120 * 2 seconds.
    context = BenchmarkRegistry.create_benchmark_context(
        model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
    )
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    if not benchmark:
        # Nothing to report when the launch did not produce a benchmark object.
        return

    summary = 'benchmark: {}, return code: {}, result: {}'.format(
        benchmark.name, benchmark.return_code, benchmark.result
    )
    logger.info(summary)


if __name__ == '__main__':
    main()
superbench/benchmarks/model_benchmarks/__init__.py
View file @
af567cf6
...
...
@@ -5,5 +5,6 @@
from
superbench.benchmarks.model_benchmarks.model_base
import
ModelBenchmark
from
superbench.benchmarks.model_benchmarks.pytorch_bert
import
PytorchBERT
from
superbench.benchmarks.model_benchmarks.pytorch_gpt2
import
PytorchGPT2
# Names exported by this package. The earlier two-element assignment was
# immediately overwritten by this one, so only the effective list is kept.
__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2']
superbench/benchmarks/model_benchmarks/pytorch_gpt2.py
0 → 100644
View file @
af567cf6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the Pytorch GPT2 model."""
import
time
import
torch
from
transformers
import
GPT2Model
,
GPT2Config
from
superbench.common.utils
import
logger
from
superbench.benchmarks
import
BenchmarkRegistry
,
Precision
from
superbench.benchmarks.model_benchmarks.model_base
import
Optimizer
from
superbench.benchmarks.model_benchmarks.pytorch_base
import
PytorchBase
from
superbench.benchmarks.model_benchmarks.random_dataset
import
TorchRandomDataset
class GPT2BenchmarkModel(torch.nn.Module):
    """The GPT2 model for benchmarking: a GPT2 backbone followed by a linear layer."""
    def __init__(self, config, num_class):
        """Constructor.

        Args:
            config (GPT2Config): Configurations of GPT2 model.
            num_class (int): The number of objects for classification.
        """
        super().__init__()
        # Named after the wrapped backbone (GPT2Model).
        self._gpt2 = GPT2Model(config)
        self._linear = torch.nn.Linear(config.hidden_size, num_class)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
                shape (batch_size, sequence_length).

        Return:
            result (torch.Tensor): The first output of the GPT2 backbone further processed by
                the linear layer. Per the Hugging Face GPT2Model documentation the first output
                is the last-layer hidden state of every position, so the expected shape is
                (batch_size, sequence_length, num_class) - TODO confirm against the installed
                transformers version.
        """
        hidden_states = self._gpt2(input)[0]
        return self._linear(hidden_states)
class PytorchGPT2(PytorchBase):
    """The GPT2 benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        # The GPT2Config is built in _create_model() once the arguments are available.
        self._config = None
        self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the GPT2-specified arguments.

        GPT2 model reference: https://huggingface.co/transformers/model_doc/gpt2.html
        """
        super().add_parser_arguments()

        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
        self._parser.add_argument('--hidden_size', type=int, default=1280, required=False, help='Hidden size.')
        self._parser.add_argument(
            '--num_hidden_layers', type=int, default=36, required=False, help='The number of hidden layers.'
        )
        self._parser.add_argument(
            '--num_attention_heads', type=int, default=20, required=False, help='The number of attention heads.'
        )
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        Return:
            True if dataset is created successfully.
        """
        # Random token indices with shape [sample_count, seq_len].
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
        )
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            True if the model is created successfully, False if creating the model,
            casting it to the precision or moving it to the GPU raised an exception.
        """
        self._config = GPT2Config(
            n_embd=self._args.hidden_size,
            n_layer=self._args.num_hidden_layers,
            n_head=self._args.num_attention_heads
        )

        try:
            self._model = GPT2BenchmarkModel(self._config, self._args.num_classes)
            # precision.value is used as the name of the torch dtype to cast the model to.
            self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except Exception as e:
            # Exception (not BaseException) so KeyboardInterrupt/SystemExit still propagate.
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

        # One random class index per sample of a batch, reused as the label of every training step.
        self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
        if self._gpu_available:
            self._target = self._target.cuda()

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            The step-time list of every training step.
        """
        duration = []
        curr_step = 0
        # Keep cycling over the dataloader until _is_finished() reports completion.
        while True:
            for sample in self._dataloader:
                start = time.time()
                if self._gpu_available:
                    sample = sample.cuda()
                self._optimizer.zero_grad()
                output = self._model(sample)
                # The model output at the last sequence position of every sample is compared with the target.
                loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
                loss.backward()
                self._optimizer.step()
                if self._gpu_available:
                    # CUDA work is queued asynchronously; wait for it so that the step time covers
                    # the whole step, as _inference_step() already does.
                    torch.cuda.synchronize()
                end = time.time()
                curr_step += 1
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                if self._is_finished(curr_step, end):
                    return duration

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        with torch.no_grad():
            self._model.eval()
            # Keep cycling over the dataloader until _is_finished() reports completion.
            while True:
                for sample in self._dataloader:
                    start = time.time()
                    if self._gpu_available:
                        sample = sample.cuda()
                    self._model(sample)
                    if self._gpu_available:
                        # Wait for the queued CUDA work before reading the end timestamp.
                        torch.cuda.synchronize()
                    end = time.time()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        # Save the step time of every training/inference step, unit is millisecond.
                        duration.append((end - start) * 1000)
                    if self._is_finished(curr_step, end):
                        return duration
# Predefined GPT2 variants as (benchmark name, model-shape parameters), kept in registration order.
# Reference: https://huggingface.co/transformers/pretrained_models.html
_GPT2_VARIANTS = (
    # GPT2 benchmark with 117M parameters.
    ('pytorch-gpt2-small', '--hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12'),
    # GPT2 benchmark with 345M parameters.
    ('pytorch-gpt2-medium', '--hidden_size=1024 --num_hidden_layers=24 --num_attention_heads=16'),
    # GPT2 benchmark with 774M parameters.
    ('pytorch-gpt2-large', '--hidden_size=1280 --num_hidden_layers=36 --num_attention_heads=20'),
    # GPT2 benchmark with 1558M parameters.
    ('pytorch-gpt2-xl', '--hidden_size=1600 --num_hidden_layers=48 --num_attention_heads=25'),
)

# Register every predefined variant against the same benchmark class.
for _variant_name, _variant_parameters in _GPT2_VARIANTS:
    BenchmarkRegistry.register_benchmark(_variant_name, PytorchGPT2, parameters=_variant_parameters)
tests/benchmarks/model_benchmarks/test_pytorch_gpt2.py
0 → 100644
View file @
af567cf6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for GPT2 model benchmarks."""
from
tests.helper
import
decorator
from
superbench.benchmarks
import
BenchmarkRegistry
,
Platform
,
Framework
,
BenchmarkType
,
ReturnCode
from
superbench.benchmarks.model_benchmarks.pytorch_gpt2
import
PytorchGPT2
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_gpt2_small():
    """Test pytorch-gpt2-small benchmark."""
    # Adjacent string literals are used instead of a backslash continuation inside the literal,
    # so that no source indentation ends up in the parameter string.
    context = BenchmarkRegistry.create_benchmark_context(
        'gpt2-small',
        platform=Platform.CUDA,
        parameters=(
            '--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 '
            '--model_action train inference'
        ),
        framework=Framework.PYTORCH
    )

    assert BenchmarkRegistry.is_benchmark_context_valid(context)

    benchmark = BenchmarkRegistry.launch_benchmark(context)

    # Check basic information.
    assert benchmark
    assert isinstance(benchmark, PytorchGPT2)
    assert benchmark.name == 'pytorch-gpt2-small'
    assert benchmark.type == BenchmarkType.MODEL

    # Check predefined parameters of gpt2-small model.
    assert benchmark._args.hidden_size == 768
    assert benchmark._args.num_hidden_layers == 12
    assert benchmark._args.num_attention_heads == 12

    # Check parameters specified in BenchmarkContext.
    assert benchmark._args.batch_size == 1
    assert benchmark._args.num_classes == 5
    assert benchmark._args.seq_len == 8
    assert benchmark._args.num_warmup == 2
    assert benchmark._args.num_steps == 4

    # Test Dataset.
    assert len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size

    # Check results and metrics.
    assert benchmark.run_count == 1
    assert benchmark.return_code == ReturnCode.SUCCESS
    expected_metrics = [
        'steptime_train_float32',
        'throughput_train_float32',
        'steptime_train_float16',
        'throughput_train_float16',
        'steptime_inference_float32',
        'throughput_inference_float32',
        'steptime_inference_float16',
        'throughput_inference_float16',
    ]
    for metric in expected_metrics:
        # One raw-data series per run, each holding one value per measured step.
        assert len(benchmark.raw_data[metric]) == benchmark.run_count
        assert len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps
        assert len(benchmark.result[metric]) == benchmark.run_count
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment