test_context.py 574 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for BenchmarkResult module."""

from superbench.benchmarks import BenchmarkContext, Platform, Framework


def test_benchmark_context():
    """Test that BenchmarkContext keeps the values it was constructed with.

    Builds a context for the 'pytorch-bert-large' benchmark on the CUDA
    platform, with a parameter string and an explicit PyTorch framework.
    Checks that name, platform, parameters and framework read back unchanged.
    """
    context = BenchmarkContext('pytorch-bert-large', Platform.CUDA, '--batch_size 8', framework=Framework.PYTORCH)

    assert context.name == 'pytorch-bert-large'
    assert context.platform == Platform.CUDA
    assert context.parameters == '--batch_size 8'
    assert context.framework == Framework.PYTORCH