# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for BERT model benchmarks."""

from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Platform, Framework, BenchmarkType, ReturnCode
from superbench.benchmarks.model_benchmarks.pytorch_bert import PytorchBERT


@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_bert_base():
    """Test pytorch-bert-base benchmark.

    Builds a 'bert-base' benchmark context for the CUDA platform and the PyTorch
    framework, launches it, and then checks the returned benchmark's identity,
    its parsed arguments, the size of its dataset, and the shape of the recorded
    raw data and results for every expected metric.
    """
    context = BenchmarkRegistry.create_benchmark_context(
        'bert-base',
        platform=Platform.CUDA,
        # NOTE: the trailing backslash continues the string literal itself, so the
        # leading spaces of the next line are part of the parameter string.
        parameters='--batch_size 1 --num_classes 5 --seq_len 8 --num_warmup 2 --num_steps 4 \
            --model_action train inference',
        framework=Framework.PYTORCH
    )

    assert (BenchmarkRegistry.is_benchmark_context_valid(context))

    benchmark = BenchmarkRegistry.launch_benchmark(context)

    # Check basic information.
    assert (benchmark)
    assert (isinstance(benchmark, PytorchBERT))
    assert (benchmark.name == 'pytorch-bert-base')
    assert (benchmark.type == BenchmarkType.MODEL)

    # Check predefined parameters of bert-base model.
    # These values are not passed in `parameters` above, so they must come from the model preset.
    assert (benchmark._args.hidden_size == 768)
    assert (benchmark._args.num_hidden_layers == 12)
    assert (benchmark._args.num_attention_heads == 12)
    assert (benchmark._args.intermediate_size == 3072)

    # Check parameters specified in BenchmarkContext.
    assert (benchmark._args.batch_size == 1)
    assert (benchmark._args.num_classes == 5)
    assert (benchmark._args.seq_len == 8)
    assert (benchmark._args.num_warmup == 2)
    assert (benchmark._args.num_steps == 4)

    # Check dataset scale.
    assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)

    # Check results and metrics.
    assert (benchmark.run_count == 1)
    assert (benchmark.return_code == ReturnCode.SUCCESS)
    # Both requested model actions (train, inference) are expected to report
    # step time and throughput under the float32 and float16 metric names.
    for metric in [
        'steptime_train_float32', 'throughput_train_float32', 'steptime_train_float16', 'throughput_train_float16',
        'steptime_inference_float32', 'throughput_inference_float32', 'steptime_inference_float16',
        'throughput_inference_float16'
    ]:
        # One raw-data series and one result entry per run ...
        assert (len(benchmark.raw_data[metric]) == benchmark.run_count)
        # ... and each raw series holds one sample per measured step.
        assert (len(benchmark.raw_data[metric][0]) == benchmark._args.num_steps)
        assert (len(benchmark.result[metric]) == benchmark.run_count)