Unverified Commit 1be7afd6 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

rename TestWeights to appease pytest (#5054)

parent cca452f0
......@@ -203,7 +203,7 @@ def test_smoke():
# With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface:
class TestWeights(WeightsEnum):
class ModelWeights(WeightsEnum):
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
@pytest.mark.parametrize(
......@@ -211,11 +211,11 @@ class TestHandleLegacyInterface:
[
pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"),
pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
],
)
def test_no_warn(self, kwargs):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
pass
......@@ -223,7 +223,7 @@ class TestHandleLegacyInterface:
@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
pass
......@@ -232,7 +232,7 @@ class TestHandleLegacyInterface:
@pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
pass
......@@ -242,12 +242,12 @@ class TestHandleLegacyInterface:
@pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel))
@handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None):
pass
args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"):
with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
builder(*args, **kwargs)
def test_multi_params(self):
......@@ -256,7 +256,7 @@ class TestHandleLegacyInterface:
@handle_legacy_interface(
**{
weights_param: (pretrained_param, self.TestWeights.Sentinel)
weights_param: (pretrained_param, self.ModelWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params)
}
)
......@@ -271,7 +271,7 @@ class TestHandleLegacyInterface:
@handle_legacy_interface(
weights=(
"pretrained",
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None,
lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
)
)
def builder(*, weights=None, flag):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment