"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "0785dba4df988119955b5380877e50d134416101"
Unverified Commit 74feb198 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Skip big models per platform/device (#6539)

* Skip big models per platform/device

* Specifying skips on Windows only.

* Simplify and clean up code.
parent 9b432d07
......@@ -3,6 +3,7 @@ import functools
import operator
import os
import pkgutil
import platform
import sys
import warnings
from collections import OrderedDict
......@@ -343,12 +344,25 @@ for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
# skip big models to reduce memory usage on CI test
# skip big models to reduce memory usage on CI test. We can exclude combinations of (platform-system, device).
skipped_big_models = {
"vit_h_14",
"regnet_y_128gf",
"vit_h_14": {("Windows", "cpu"), ("Windows", "cuda")},
"regnet_y_128gf": {("Windows", "cpu"), ("Windows", "cuda")},
"mvit_v1_b": {("Windows", "cuda")},
"mvit_v2_s": {("Windows", "cuda")},
}
def is_skippable(model_name, device):
if model_name not in skipped_big_models:
return False
platform_system = platform.system()
device_name = str(device).split(":")[0]
return (platform_system, device_name) in skipped_big_models[model_name]
# The following contains configuration and expected values to be used tests that are model specific
_model_tests_values = {
"retinanet_resnet50_fpn": {
......@@ -612,7 +626,7 @@ def test_classification_model(model_fn, dev):
"input_shape": (1, 3, 224, 224),
}
model_name = model_fn.__name__
if SKIP_BIG_MODEL and model_name in skipped_big_models:
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
......@@ -841,7 +855,7 @@ def test_video_model(model_fn, dev):
"num_classes": 50,
}
model_name = model_fn.__name__
if SKIP_BIG_MODEL and model_name in skipped_big_models:
if SKIP_BIG_MODEL and is_skippable(model_name, dev):
pytest.skip("Skipped to reduce memory usage. Set env var SKIP_BIG_MODEL=0 to enable test for this model")
kwargs = {**defaults, **_model_params.get(model_name, {})}
num_classes = kwargs.get("num_classes")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment