# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for BERT model benchmarks."""

from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, Framework

import superbench.benchmarks.model_benchmarks.pytorch_bert as pybert


def test_pytorch_bert_base():
    """Test pytorch-bert-base benchmark.

    Covers: context creation/validation, registry name resolution, merging of
    predefined parameters with user parameters, preprocessing, and model
    creation for the bert-base variant.
    """
    context = BenchmarkRegistry.create_benchmark_context(
        'bert-base',
        platform=Platform.CUDA,
        parameters='--batch_size=32 --num_classes=5 --seq_len=512',
        framework=Framework.PYTORCH
    )

    assert (BenchmarkRegistry.is_benchmark_context_valid(context))

    # Name-mangled access to the registry's private helpers, mirroring what
    # BenchmarkRegistry does internally when launching a benchmark.
    benchmark_name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
    assert (benchmark_name == 'pytorch-bert-base')

    (benchmark_class,
     predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, context.platform)
    # Identity check: the registry must return the exact registered class.
    assert (benchmark_class is pybert.PytorchBERT)

    # Predefined (registry-supplied) parameters come first so user parameters
    # can override them during argument parsing.
    parameters = context.parameters
    if predefine_params:
        parameters = predefine_params + ' ' + parameters

    benchmark = benchmark_class(benchmark_name, parameters)
    assert (benchmark._preprocess() is True)

    # Predefined parameters of bert-base model.
    assert (benchmark._args.hidden_size == 768)
    assert (benchmark._args.num_hidden_layers == 12)
    assert (benchmark._args.num_attention_heads == 12)
    assert (benchmark._args.intermediate_size == 3072)

    # Parameters from BenchmarkContext.
    assert (benchmark._args.batch_size == 32)
    assert (benchmark._args.num_classes == 5)
    assert (benchmark._args.seq_len == 512)

    # Test Dataset: one sample_count-sized shard per participating rank.
    assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)

    # Test _create_model().
    assert (benchmark._create_model(Precision.FLOAT32) is True)
    assert (isinstance(benchmark._model, pybert.BertBenchmarkModel))


def test_pytorch_bert_large():
    """Test pytorch-bert-large benchmark.

    Covers: context creation/validation, registry name resolution, merging of
    predefined parameters with user parameters, preprocessing, and model
    creation for the bert-large variant.
    """
    context = BenchmarkRegistry.create_benchmark_context(
        'bert-large',
        platform=Platform.CUDA,
        parameters='--batch_size=32 --num_classes=5 --seq_len=512',
        framework=Framework.PYTORCH
    )

    assert (BenchmarkRegistry.is_benchmark_context_valid(context))

    # Name-mangled access to the registry's private helpers, mirroring what
    # BenchmarkRegistry does internally when launching a benchmark.
    benchmark_name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
    assert (benchmark_name == 'pytorch-bert-large')

    (benchmark_class,
     predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, context.platform)
    # Identity check: the registry must return the exact registered class.
    assert (benchmark_class is pybert.PytorchBERT)

    # Predefined (registry-supplied) parameters come first so user parameters
    # can override them during argument parsing.
    parameters = context.parameters
    if predefine_params:
        parameters = predefine_params + ' ' + parameters

    benchmark = benchmark_class(benchmark_name, parameters)
    assert (benchmark._preprocess() is True)

    # Predefined parameters of bert-large model.
    assert (benchmark._args.hidden_size == 1024)
    assert (benchmark._args.num_hidden_layers == 24)
    assert (benchmark._args.num_attention_heads == 16)
    assert (benchmark._args.intermediate_size == 4096)

    # Parameters from BenchmarkContext.
    assert (benchmark._args.batch_size == 32)
    assert (benchmark._args.num_classes == 5)
    assert (benchmark._args.seq_len == 512)

    # Test Dataset: one sample_count-sized shard per participating rank.
    assert (len(benchmark._dataset) == benchmark._args.sample_count * benchmark._world_size)

    # Test _create_model().
    assert (benchmark._create_model(Precision.FLOAT32) is True)
    assert (isinstance(benchmark._model, pybert.BertBenchmarkModel))