Unverified Commit 7c0534cc authored by Yifan Xiong's avatar Yifan Xiong Committed by GitHub
Browse files

Skip tests and remove useless tests (#42)

* skip unnecessary tests according to env var
* remove useless tests
parent 7bd41649
......@@ -8,6 +8,7 @@
import torch
from tests.helper import decorator
from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision, ReturnCode
from superbench.benchmarks.model_benchmarks.model_base import Optimizer, DistributedImpl, DistributedBackend
......@@ -169,6 +170,7 @@ def _inference_step(self, precision):
return duration
@decorator.pytorch_test
def test_pytorch_base():
"""Test PytorchBase class."""
# Register BERT Base benchmark.
......
......@@ -3,10 +3,13 @@
"""Tests for BERT model benchmarks."""
from tests.helper import decorator
from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, Framework
import superbench.benchmarks.model_benchmarks.pytorch_bert as pybert
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_bert_base():
"""Test pytorch-bert-base benchmark."""
context = BenchmarkRegistry.create_benchmark_context(
......@@ -51,6 +54,8 @@ def test_pytorch_bert_base():
assert (isinstance(benchmark._model, pybert.BertBenchmarkModel))
@decorator.cuda_test
@decorator.pytorch_test
def test_pytorch_bert_large():
"""Test pytorch-bert-large benchmark."""
context = BenchmarkRegistry.create_benchmark_context(
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Unittest decorator helpers."""
import os
import unittest
cuda_test = unittest.skipIf(os.environ.get('SB_TEST_CUDA', '1') == '0', 'Skip CUDA tests.')
rocm_test = unittest.skipIf(os.environ.get('SB_TEST_ROCM', '0') == '0', 'Skip ROCm tests.')
pytorch_test = unittest.skipIf(os.environ.get('SB_TEST_PYTORCH', '1') == '0', 'Skip PyTorch tests.')
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Test example.
Get it from https://docs.pytest.org/en/stable/.
"""
import superbench
def inc(x):
"""Increase an integer.
Args:
x (int): Input value.
Returns:
int: Increased value.
"""
return x + 1
def test_answer():
"""Test inc function."""
assert inc(3) == 4
def test_superbench():
"""Test SuperBench."""
assert (superbench.__version__ == '0.0.0')
assert (superbench.__author__ == 'Microsoft')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment