Unverified commit 249e21c1, authored by pdr and committed by GitHub

Benchmarks - Add LLaMA-2 Models (#668)

Added a LLaMA-2 benchmark with training and inference, following the existing
PyTorch model implementations such as gpt2 and lstm.

- added a LLaMA FP8 unit test for better code coverage while keeping the
  required memory low
- updated the transformers requirement to >= 4.28.0, which provides LlamaConfig
- pinned tokenizers <= 0.20.3 to avoid the 0.20.4
  [issues](https://github.com/huggingface/tokenizers/issues/1691) with py3.8
- added llama2 to the TensorRT exporter
- llama2 tests were not added to test_tensorrt_inference_performance.py because
  of the large memory requirement on the worker GPU; they were validated
  separately on a GH200

---------
Co-authored-by: dpatlolla <dpatlolla@microsoft.com>
parent 4e6935ab
@@ -328,7 +328,8 @@ A list of models to run, only supported in model-benchmark.
   shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 |
   squeezenet1_0 | squeezenet1_1 |
   vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19_bn | vgg19 |
-  bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl ]
+  bert-base | bert-large | gpt2-small | gpt2-medium | gpt2-large | gpt2-xl |
+  llama2-7b | llama2-13b | llama2-70b ]
   ```

 * default value: `[ ]`
......
@@ -13,6 +13,7 @@ id: model-benchmarks
 Run training or inference tasks with single or half precision for deep learning models,
 including the following categories:
 * GPT: gpt2-small, gpt2-medium, gpt2-large and gpt2-xl
+* LLAMA: llama2-7b, llama2-13b, llama2-70b
 * BERT: bert-base and bert-large
 * LSTM
 * CNN, listed in [`torchvision.models`](https://pytorch.org/vision/0.8/models.html), including:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Model benchmark example for Llama2-7b (32-layer, 4096-hidden, 32-heads, 7B parameters).
Commands to run:
python3 examples/benchmarks/pytorch_llama2.py (Single GPU)
python3 -m torch.distributed.launch --use_env --nproc_per_node=8 examples/benchmarks/pytorch_llama2.py \
--distributed (Distributed)
"""
import argparse
from superbench.benchmarks import Platform, Framework, BenchmarkRegistry
from superbench.common.utils import logger
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--distributed', action='store_true', default=False, help='Whether to enable distributed training.'
)
args = parser.parse_args()
# Specify the model name and benchmark parameters.
model_name = 'llama2-7b'
parameters = '--batch_size 1 --duration 120 --seq_len 512 --precision float16'
if args.distributed:
parameters += ' --distributed_impl ddp --distributed_backend nccl'
# Create context for Llama2 benchmark and run it for 120 seconds.
context = BenchmarkRegistry.create_benchmark_context(
model_name, platform=Platform.CUDA, parameters=parameters, framework=Framework.PYTORCH
)
benchmark = BenchmarkRegistry.launch_benchmark(context)
if benchmark:
logger.info(
'benchmark: {}, return code: {}, result: {}'.format(
benchmark.name, benchmark.return_code, benchmark.result
)
)
@@ -209,9 +209,10 @@ def run(self):
         'yapf==0.31.0',
     ],
     'torch': [
+        'tokenizers<=0.20.3',
         'torch>=1.7.0a0',
         'torchvision>=0.8.0a0',
-        'transformers>=4.3.3, <4.23.0',
+        'transformers>=4.28.0',
     ],
     'ort': [
         'onnx>=1.10.2',
......
@@ -89,7 +89,8 @@ def get_configurable_settings(self):
         Return:
             All configurable settings in raw string.
         """
-        return self._parser.format_help().strip()
+        message = self._parser.format_help().strip()
+        return message

     def parse_args(self, ignore_invalid=False):
         """Parse the arguments.
......
@@ -9,11 +9,12 @@
 import torch.hub
 import torch.onnx
 import torchvision.models
-from transformers import BertConfig, GPT2Config
+from transformers import BertConfig, GPT2Config, LlamaConfig

 from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
+from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel


 class torch2onnxExporter():
@@ -87,6 +88,39 @@ def __init__(self):
                 ),
                 self.num_classes,
             ),
+            'llama2-7b':
+            lambda: LlamaBenchmarkModel(
+                LlamaConfig(
+                    hidden_size=4096,
+                    num_hidden_layers=32,
+                    num_attention_heads=32,
+                    num_key_value_heads=32,
+                    intermediate_size=11008,
+                ),
+                self.num_classes,
+            ),
+            'llama2-13b':
+            lambda: LlamaBenchmarkModel(
+                LlamaConfig(
+                    hidden_size=5120,
+                    num_hidden_layers=40,
+                    num_attention_heads=40,
+                    num_key_value_heads=40,
+                    intermediate_size=13824,
+                ),
+                self.num_classes,
+            ),
+            'llama2-70b':
+            lambda: LlamaBenchmarkModel(
+                LlamaConfig(
+                    hidden_size=8192,
+                    num_hidden_layers=80,
+                    num_attention_heads=64,
+                    num_key_value_heads=8,
+                    intermediate_size=28672,
+                ),
+                self.num_classes,
+            ),
         }
         self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
         self._onnx_model_path.mkdir(parents=True, exist_ok=True)
@@ -138,7 +172,7 @@ def export_torchvision_model(self, model_name, batch_size=1):
             model,
             dummy_input,
             file_name,
-            opset_version=10,
+            opset_version=14,
             operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
             input_names=['input'],
             output_names=['output'],
@@ -179,7 +213,7 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
             model,
             dummy_input,
             file_name,
-            opset_version=10,
+            opset_version=14,
             do_constant_folding=True,
             input_names=['input'],
             output_names=['output'],
......
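For reference, a minimal sketch of driving the extended exporter. The `export_benchmark_model` signature matches the diff above, but the module path in the import and the returned value (assumed to be the generated ONNX file path) are assumptions:

# Sketch only: the module path below is an assumption about where
# torch2onnxExporter lives in the package.
from superbench.benchmarks.micro_benchmarks._export_torch_to_onnx import torch2onnxExporter

exporter = torch2onnxExporter()
# Exports the model under <torch_hub_dir>/onnx; exporting llama2-7b at
# seq_length 512 needs tens of GB of host memory.
onnx_file = exporter.export_benchmark_model('llama2-7b', batch_size=1, seq_length=512)
print(onnx_file)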
@@ -10,4 +10,4 @@
 from superbench.benchmarks.model_benchmarks.pytorch_lstm import PytorchLSTM
 from superbench.benchmarks.model_benchmarks.megatron_gpt3 import MegatronGPT

-__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT']
+__all__ = ['ModelBenchmark', 'PytorchBERT', 'PytorchGPT2', 'PytorchCNN', 'PytorchLSTM', 'MegatronGPT', 'PytorchLlama']
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Module of the Pytorch Llama2 model."""
import torch
from transformers import LlamaModel, LlamaConfig
try:
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
te = None
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset
class LlamaBenchmarkModel(torch.nn.Module):
"""The Llama model for benchmarking."""
def __init__(self, config, num_classes):
"""Constructor.
Args:
config (LlamaConfig): Configurations of Llama model.
num_classes (int): The number of objects for classification.
"""
super().__init__()
self._llama = LlamaModel(config)
self._linear = torch.nn.Linear(config.hidden_size, num_classes)
def forward(self, input):
"""Forward propagation function.
Args:
input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
shape (batch_size, sequence_length).
        Return:
            result (torch.FloatTensor): Per-token last-layer hidden-states projected through
                the Linear classification layer, shape (batch_size, sequence_length, num_classes).
"""
outputs = self._llama(input)
result = self._linear(outputs[0])
return result
class PytorchLlama(PytorchBase):
"""The Llama benchmark class."""
def __init__(self, name, parameters=''):
"""Constructor.
Args:
name (str): benchmark name.
parameters (str): benchmark parameters.
"""
super().__init__(name, parameters)
self._config = None
self._fp8_recipe = None
self._supported_precision = [
Precision.FLOAT32,
Precision.FLOAT16,
Precision.FP8_HYBRID,
Precision.FP8_E4M3,
]
self._optimizer_type = Optimizer.ADAMW
self._loss_fn = torch.nn.CrossEntropyLoss()
def add_parser_arguments(self):
"""Add the Llama-specified arguments.
Llama2 model reference: https://huggingface.co/docs/transformers/model_doc/llama2
"""
super().add_parser_arguments()
self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
self._parser.add_argument('--hidden_size', type=int, default=1280, required=False, help='Hidden size.')
self._parser.add_argument(
'--num_hidden_layers', type=int, default=36, required=False, help='The number of hidden layers.'
)
self._parser.add_argument(
'--num_attention_heads', type=int, default=20, required=False, help='The number of attention heads.'
)
self._parser.add_argument(
'--intermediate_size',
type=int,
default=11008,
required=False,
help='Dimension of the MLP representations.'
)
self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
self._parser.add_argument(
'--num_key_value_heads',
type=int,
default=None,
required=False,
help='The number of key_value heads that should be used to implement Grouped Query Attention.'
)
def _generate_dataset(self):
"""Generate dataset for benchmarking according to shape info.
Return:
True if dataset is created successfully.
"""
self._dataset = TorchRandomDataset(
[self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
)
if len(self._dataset) == 0:
logger.error('Generate random dataset failed - model: {}'.format(self._name))
return False
return True
def _create_model(self, precision):
"""Construct the model for benchmarking.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
"""
self._config = LlamaConfig(
hidden_size=self._args.hidden_size,
num_hidden_layers=self._args.num_hidden_layers,
num_attention_heads=self._args.num_attention_heads,
num_key_value_heads=self._args.num_key_value_heads,
intermediate_size=self._args.intermediate_size,
max_position_embeddings=4096, # Maximum sequence length that llama2 supports
rms_norm_eps=1e-05, # Llama2 default for epsilon used by the rms normalization layers
)
enable_fp8 = precision.name.startswith('FP8_')
if enable_fp8 and te is None:
logger.error(
f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
' message: Cannot find transformer_engine.'
)
return False
if enable_fp8 and not self._gpu_available:
logger.error(
f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
' message: FP8 is only supported on GPU.'
)
return False
try:
self._model = LlamaBenchmarkModel(self._config, self._args.num_classes)
if enable_fp8:
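                # Build the FP8 recipe: a 16-step amax history with max-based scaling;
                # strip('FP8_') removes the 'FP8_' prefix characters from the Precision
                # name, leaving 'HYBRID' or 'E4M3' to index the Transformer Engine Format.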
self._fp8_recipe = DelayedScaling(
fp8_format=Format[precision.name.strip('FP8_')],
amax_history_len=16,
amax_compute_algo='max',
)
self._to_te_model(self._model.to(dtype=torch.float16))
else:
self._model = self._model.to(dtype=getattr(torch, precision.value))
if self._gpu_available:
self._model = self._model.cuda()
except BaseException as e:
logger.error(
'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
self._name, precision, str(e)
)
)
return False
self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
if self._gpu_available:
self._target = self._target.cuda()
return True
def _train_step(self, precision):
"""Define the training process.
Args:
precision (Precision): precision of model and input data, such as float32, float16.
Return:
The step-time list of every training step.
"""
duration = []
curr_step = 0
check_frequency = 100
while True:
for idx, sample in enumerate(self._dataloader):
start = self._timer()
if self._gpu_available:
sample = sample.cuda()
self._optimizer.zero_grad()
if self._fp8_recipe is not None:
with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
output = self._model(sample)
else:
output = self._model(sample)
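                # Compute the loss from the class logits at the last token position of each sequence.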
loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
loss.backward()
self._optimizer.step()
end = self._timer()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000)
self._log_step_time(curr_step, precision, duration)
if self._is_finished(curr_step, end, check_frequency):
return duration
def _inference_step(self, precision):
"""Define the inference process.
Args:
precision (Precision): precision of model and input data,
such as float32, float16.
Return:
The latency list of every inference operation.
"""
duration = []
curr_step = 0
with torch.no_grad():
self._model.eval()
while True:
for idx, sample in enumerate(self._dataloader):
start = self._timer()
if self._gpu_available:
sample = sample.cuda()
if self._fp8_recipe is not None:
with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
self._model(sample)
else:
self._model(sample)
end = self._timer()
curr_step += 1
if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000)
self._log_step_time(curr_step, precision, duration)
if self._is_finished(curr_step, end):
return duration
# Register Llama2 benchmark with 7b parameters.
BenchmarkRegistry.register_benchmark(
'pytorch-llama2-7b',
PytorchLlama,
parameters='--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --num_key_value_heads=32 \
--intermediate_size=11008'
)
# Register Llama2 benchmark with 13b parameters.
BenchmarkRegistry.register_benchmark(
'pytorch-llama2-13b',
PytorchLlama,
parameters='--hidden_size=5120 --num_hidden_layers=40 --num_attention_heads=40 --num_key_value_heads=40 \
--intermediate_size=13824'
)
# Register Llama2 benchmark with 70b parameters.
BenchmarkRegistry.register_benchmark(
'pytorch-llama2-70b',
PytorchLlama,
parameters='--hidden_size=8192 --num_hidden_layers=80 --num_attention_heads=64 --num_key_value_heads=8 \
--intermediate_size=28672'
)
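As a quick sanity check, the wrapper above can be exercised stand-alone with a tiny, purely illustrative configuration (the sizes below are not any published llama2 variant):

import torch
from transformers import LlamaConfig

from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel

# Tiny illustrative sizes so the forward pass runs on CPU in seconds.
config = LlamaConfig(
    hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4, intermediate_size=128
)
model = LlamaBenchmarkModel(config, num_classes=10)
tokens = torch.randint(0, config.vocab_size, (2, 16), dtype=torch.long)  # (batch_size, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 16, 10]): per-token class logits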
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for Llama model benchmarks."""
from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.model_benchmarks.pytorch_llama import PytorchLlama
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_llama_7b():
"""Test pytorch-llama2-7b benchmark for fp16 train and inference."""
context = BenchmarkRegistry.create_benchmark_context(
'llama2-7b',
platform=Platform.CUDA,
parameters='--batch_size 1 --seq_len 32 --num_warmup 1 --num_steps 2 --precision float16 \
--model_action train inference',
framework=Framework.PYTORCH
)
assert (BenchmarkRegistry.is_benchmark_context_valid(context))
benchmark = BenchmarkRegistry.launch_benchmark(context)
# Check basic information.
assert (benchmark)
assert (isinstance(benchmark, PytorchLlama))
assert (benchmark.name == 'pytorch-llama2-7b')
assert (benchmark.type == BenchmarkType.MODEL)
# Check predefined parameters of llama2 7b model.
assert (benchmark._args.hidden_size == 4096)
assert (benchmark._args.num_hidden_layers == 32)
assert (benchmark._args.num_attention_heads == 32)
# Check parameters specified in BenchmarkContext.
assert (benchmark._args.batch_size == 1)
assert (benchmark._args.num_classes == 100)
assert (benchmark._args.seq_len == 32)
assert (benchmark._args.num_warmup == 1)
assert (benchmark._args.num_steps == 2)
# Test Dataset.
assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)
# Check results and metrics.
assert (benchmark.run_count == 1)
assert (benchmark.return_code == ReturnCode.SUCCESS)
for metric in [
'fp16_train_step_time', 'fp16_train_throughput', 'fp16_inference_step_time', 'fp16_inference_throughput'
]:
assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
assert (len(benchmark.result[metric]) == benchmark.run_count)
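The FP8 unit test mentioned in the commit message is not shown in this excerpt. A minimal sketch of what it could look like, assuming 'fp8_hybrid' is the CLI value of Precision.FP8_HYBRID and that context parameters can shrink the model to fit the worker GPU:

@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_llama_fp8():
    """Sketch: FP8 train step on a shrunken llama2 model to keep GPU memory low."""
    context = BenchmarkRegistry.create_benchmark_context(
        'llama2-7b',
        platform=Platform.CUDA,
        # Assumption: overriding num_hidden_layers here reduces the model depth.
        parameters='--batch_size 1 --seq_len 32 --num_hidden_layers 2 --num_warmup 1 --num_steps 2 \
            --precision fp8_hybrid --model_action train',
        framework=Framework.PYTORCH
    )
    benchmark = BenchmarkRegistry.launch_benchmark(context)
    assert (benchmark)
    assert (benchmark.return_code == ReturnCode.SUCCESS)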