"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "fe95bc332d92c6e3f5c2e07fd681bd3549b77374"
Unverified commit cb8a3cfb authored by Yifan Xiong, committed by GitHub

Benchmarks - Add transformers for TensorRT inference (#254)

Add transformers for TensorRT inference.
parent 10012a0a
@@ -71,12 +71,16 @@ TODO
#### Introduction
Inference PyTorch/ONNX models on NVIDIA GPUs with [TensorRT](https://developer.nvidia.com/tensorrt).
Currently the following models are supported:
> alexnet, densenet121, densenet169, densenet201, densenet161, googlenet, inception_v3, mnasnet0_5,
> mnasnet1_0, mobilenet_v2, resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d,
> resnext101_32x8d, wide_resnet50_2, wide_resnet101_2, shufflenet_v2_x0_5, shufflenet_v2_x1_0,
> squeezenet1_0, squeezenet1_1, vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19
> lstm, bert-base, bert-large, gpt2-small
> Large models such as `gpt2-large` are not supported yet, because models larger than 2 GB (the maximum protobuf size) cannot be exported into a single ONNX file.
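
One possible direction for lifting this limit, sketched here only as context (this change does not implement it): ONNX can keep weights outside the graph protobuf as external data files. Assuming a PyTorch version where `torch.onnx.export` still accepts the `use_external_data_format` flag, the export could look like:

```python
import torch

# Sketch only: store weights as external data files so the graph protobuf
# itself stays under the 2 GB limit. `model` and `dummy_input` are
# placeholders, not names from this repository.
torch.onnx.export(
    model.eval(),
    dummy_input,
    'gpt2-large/model.onnx',  # weight side files are written next to it
    opset_version=10,
    use_external_data_format=True,  # flag availability depends on torch version
)
```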
#### Metrics
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Export PyTorch models to ONNX format."""
from pathlib import Path
import torch.hub
import torch.onnx
import torchvision.models
from transformers import BertConfig, GPT2Config
from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
class torch2onnxExporter():
"""PyTorch model to ONNX exporter."""
def __init__(self):
"""Constructor."""
self.num_classes = 100
self.lstm_input_size = 256
self.benchmark_models = {
'lstm':
lambda: LSTMBenchmarkModel(
self.lstm_input_size,
1024,  # hidden_size
8,  # num_layers
False,  # bidirectional
self.num_classes,
),
'bert-base':
lambda: BertBenchmarkModel(
BertConfig(
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
),
self.num_classes,
),
'bert-large':
lambda: BertBenchmarkModel(
BertConfig(
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
),
self.num_classes,
),
'gpt2-small':
lambda: GPT2BenchmarkModel(
GPT2Config(
n_embd=768,
n_layer=12,
n_head=12,
),
self.num_classes,
),
'gpt2-medium':
lambda: GPT2BenchmarkModel(
GPT2Config(
n_embd=1024,
n_layer=24,
n_head=16,
),
self.num_classes,
),
'gpt2-large':
lambda: GPT2BenchmarkModel(
GPT2Config(
n_embd=1280,
n_layer=36,
n_head=20,
),
self.num_classes,
),
'gpt2-xl':
lambda: GPT2BenchmarkModel(
GPT2Config(
n_embd=1600,
n_layer=48,
n_head=25,
),
self.num_classes,
),
}
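# Note: gpt2-medium through gpt2-xl are defined above for completeness, but per
# the README only gpt2-small is currently listed as supported, since models
# larger than 2GB cannot yet be exported into a single ONNX file.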
self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)
def check_torchvision_model(self, model_name):
"""Check whether can export the torchvision model with given name.
Args:
model_name (str): Name of torchvision model to check.
Returns:
bool: True if the model can be exported, False otherwise.
"""
return hasattr(torchvision.models, model_name)
def check_benchmark_model(self, model_name):
"""Check whether can export the benchmark model with given name.
Args:
model_name (str): Name of benchmark model to check.
Returns:
bool: True if the model can be exported, False otherwise.
"""
return model_name in self.benchmark_models
def export_torchvision_model(self, model_name, batch_size=1):
"""Export the torchvision model with given name.
Args:
model_name (str): Name of torchvision model to export.
batch_size (int): Batch size of input. Defaults to 1.
Returns:
str: Exported ONNX model file name.
"""
if not self.check_torchvision_model(model_name):
return ''
file_name = str(self._onnx_model_path / (model_name + '.onnx'))
input_shape = (batch_size, 3, 224, 224)
torch.onnx.export(
getattr(torchvision.models, model_name)(pretrained=False).eval().cuda(),
torch.randn(input_shape, device='cuda'),
file_name,
opset_version=10,
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {
0: 'batch_size',
},
'output': {
0: 'batch_size',
}
},
)
return file_name
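# Side note (sketch, not used by this module): the dynamic axes declared above
# can be sanity-checked with the `onnx` package, where dynamic dims appear as
# named dim_param entries instead of fixed sizes:
#
#   import onnx
#   m = onnx.load('/path/to/resnet50.onnx')
#   for inp in m.graph.input:
#       dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
#       print(inp.name, dims)  # e.g. input ['batch_size', 3, 224, 224]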
def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
"""Export the benchmark model with given name.
Args:
model_name (str): Name of benchmark model to export.
batch_size (int): Batch size of input. Defaults to 1.
seq_length (int): Sequence length of input. Defaults to 512.
Returns:
str: Exported ONNX model file name.
"""
if not self.check_benchmark_model(model_name):
return ''
file_name = str(self._onnx_model_path / (model_name + '.onnx'))
input_shape, dtype = (batch_size, seq_length), torch.int64
if model_name == 'lstm':
input_shape += (self.lstm_input_size, )
dtype = None
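# Note: transformer models take int64 token ids of shape (batch_size, seq_length),
# while the LSTM takes float features of shape (batch_size, seq_length, input_size),
# hence the extra dimension and the default float dtype above.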
torch.onnx.export(
self.benchmark_models[model_name]().eval().cuda(),
torch.ones(input_shape, dtype=dtype, device='cuda'),
file_name,
opset_version=10,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={
'input': {
0: 'batch_size',
1: 'seq_length',
},
'output': {
0: 'batch_size',
}
},
)
return file_name
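# Example usage (sketch, assuming a CUDA device is available):
#
#   exporter = torch2onnxExporter()
#   resnet_onnx = exporter.export_torchvision_model('resnet50', batch_size=4)
#   bert_onnx = exporter.export_benchmark_model('bert-base', batch_size=1, seq_length=128)
#   # both files are written under <torch.hub.get_dir()>/onnx/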
@@ -6,13 +6,10 @@
import re
from pathlib import Path
import torch.hub
import torch.onnx
import torchvision.models
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode
from superbench.benchmarks.micro_benchmarks import MicroBenchmarkWithInvoke
from superbench.benchmarks.micro_benchmarks._export_torch_to_onnx import torch2onnxExporter
class TensorRTInferenceBenchmark(MicroBenchmarkWithInvoke):
@@ -27,18 +24,7 @@ def __init__(self, name, parameters=''):
super().__init__(name, parameters)
self._bin_name = 'trtexec'
- self._pytorch_models = [
- 'resnet50',
- 'resnet101',
- 'resnet152',
- 'densenet169',
- 'densenet201',
- 'vgg11',
- 'vgg13',
- 'vgg16',
- 'vgg19',
- ]
- self.__model_cache_path = Path(torch.hub.get_dir()) / 'checkpoints'
self._pytorch_models = ['resnet50']
def add_parser_arguments(self):
"""Add the specified arguments."""
@@ -66,13 +52,21 @@ def add_parser_arguments(self):
type=int,
default=32,
required=False,
- help='Set batch size for implicit batch engines.',
help='Set batch size for inference input.',
)
self._parser.add_argument(
'--seq_length',
type=int,
default=512,
required=False,
help='Set sequence length for inference input, only effective for transformer models.',
)
self._parser.add_argument(
'--iterations',
type=int,
- default=256,
default=2048,
required=False,
help='Run at least N inference iterations.',
)
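# Example (sketch): exercising these arguments through the SuperBench registry,
# assuming its benchmark-context API; the parameter string maps onto the parser
# definitions above.
#
#   context = BenchmarkRegistry.create_benchmark_context(
#       'tensorrt-inference',
#       platform=Platform.CUDA,
#       parameters='--pytorch_models lstm bert-base --precision fp16 --seq_length 128',
#   )
#   benchmark = BenchmarkRegistry.launch_benchmark(context)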
@@ -88,31 +82,37 @@ def _preprocess(self):
self.__bin_path = str(Path(self._args.bin_dir) / self._bin_name)
exporter = torch2onnxExporter()
for model in self._args.pytorch_models:
- if hasattr(torchvision.models, model):
- torch.onnx.export(
- getattr(torchvision.models, model)(pretrained=True).cuda(),
- torch.randn(self._args.batch_size, 3, 224, 224, device='cuda'),
- f'{self.__model_cache_path / (model + ".onnx")}',
- )
- self._commands.append(
- ' '.join(
- filter(
- None, [
- self.__bin_path,
- None if self._args.precision == 'fp32' else f'--{self._args.precision}',
- f'--batch={self._args.batch_size}',
- f'--iterations={self._args.iterations}',
- '--workspace=1024',
- '--percentile=99',
- f'--onnx={self.__model_cache_path / (model + ".onnx")}',
- ]
- )
- )
- )
- else:
if not (exporter.check_torchvision_model(model) or exporter.check_benchmark_model(model)):
logger.error('Cannot find PyTorch model %s.', model)
return False
for model in self._args.pytorch_models:
input_shape: str
onnx_model: str
if exporter.check_torchvision_model(model):
input_shape = f'{self._args.batch_size}x3x224x224'
onnx_model = exporter.export_torchvision_model(model, self._args.batch_size)
if exporter.check_benchmark_model(model):
input_shape = f'{self._args.batch_size}x{self._args.seq_length}'
onnx_model = exporter.export_benchmark_model(model, self._args.batch_size, self._args.seq_length)
args = [
# trtexec
self.__bin_path,
# model options
f'--onnx={onnx_model}',
# build options
'--explicitBatch',
f'--optShapes=input:{input_shape}',
'--workspace=8192',
None if self._args.precision == 'fp32' else f'--{self._args.precision}',
# inference options
f'--iterations={self._args.iterations}',
# reporting options
'--percentile=99',
] # yapf: disable
self._commands.append(' '.join(filter(None, args)))
return True
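# For illustration: with the defaults above and `--pytorch_models bert-base
# --precision fp16`, the assembled command would look roughly like:
#
#   trtexec --onnx=<torch_hub>/onnx/bert-base.onnx --explicitBatch \
#       --optShapes=input:32x512 --workspace=8192 --fp16 \
#       --iterations=2048 --percentile=99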
def _process_raw_result(self, cmd_idx, raw_output):
@@ -28,14 +28,14 @@ def __init__(self, input_size, hidden_size, num_layers, bidirectional, num_class
"""
super().__init__()
self._lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
- self._linear = torch.nn.Linear(hidden_size, num_classes)
self._linear = torch.nn.Linear(hidden_size * (2 if bidirectional else 1), num_classes)
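# Why the multiplier: a bidirectional LSTM concatenates forward and backward
# hidden states, so its per-step output width is 2 * hidden_size. Quick check:
#
#   import torch
#   lstm = torch.nn.LSTM(8, 16, 2, batch_first=True, bidirectional=True)
#   out, _ = lstm(torch.randn(4, 10, 8))
#   print(out.shape)  # torch.Size([4, 10, 32]) == (batch, seq, 2 * hidden_size)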
def forward(self, input):
"""Forward propagation function.
Args:
input (torch.FloatTensor): Tensor containing the features of the input sequence,
- shape (sequence_length, batch_size, input_size).
shape (batch_size, sequence_length, input_size).
Return:
result (torch.FloatTensor): The output features from the last layer of the LSTM
@@ -21,9 +21,10 @@ def setUp(self):
"""Hook method for setting up the test fixture before exercising it."""
self.benchmark_name = 'tensorrt-inference'
self.__tmp_dir = tempfile.mkdtemp()
self.__model_path = Path(self.__tmp_dir) / 'hub' / 'onnx'
self.__curr_micro_path = os.environ.get('SB_MICRO_PATH', '')
- os.environ['SB_MICRO_PATH'] = self.__tmp_dir
os.environ['TORCH_HOME'] = self.__tmp_dir
os.environ['SB_MICRO_PATH'] = self.__tmp_dir
(Path(self.__tmp_dir) / 'bin').mkdir(parents=True, exist_ok=True)
(Path(self.__tmp_dir) / 'bin' / 'trtexec').touch(mode=0o755, exist_ok=True)
@@ -61,8 +62,10 @@ def test_tensorrt_inference_params(self):
'batch_size': 4,
},
{
'pytorch_models': ['lstm', 'bert-base', 'gpt2-small'],
'batch_size': 4,
- 'iterations': 128,
'seq_length': 128,
'iterations': 256,
},
]
for test_case in test_cases:
@@ -74,6 +77,8 @@ def test_tensorrt_inference_params(self):
parameter_list.append(f'--precision {test_case["precision"]}')
if 'batch_size' in test_case:
parameter_list.append(f'--batch_size {test_case["batch_size"]}')
if 'seq_length' in test_case:
parameter_list.append(f'--seq_length {test_case["seq_length"]}')
if 'iterations' in test_case:
parameter_list.append(f'--iterations {test_case["iterations"]}')
@@ -83,7 +88,6 @@
# Limit model number
benchmark._pytorch_models = benchmark._pytorch_models[:1]
- benchmark._TensorRTInferenceBenchmark__model_cache_path = Path(self.__tmp_dir) / 'hub/checkpoints'
# Preprocess
ret = benchmark._preprocess()
@@ -106,15 +110,13 @@
benchmark._args.batch_size,
)
self.assertEqual(
- test_case.get('iterations', 256),
test_case.get('iterations', 2048),
benchmark._args.iterations,
)
# Check models
for model in benchmark._args.pytorch_models:
- self.assertTrue(
- (benchmark._TensorRTInferenceBenchmark__model_cache_path / f'{model}.onnx').is_file()
- )
self.assertTrue((self.__model_path / f'{model}.onnx').is_file())
# Command list length should equal the default model number
self.assertEqual(