Unverified Commit 1be7afd6 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

rename TestWeights to appease pytest (#5054)

parent cca452f0
...@@ -203,7 +203,7 @@ def test_smoke(): ...@@ -203,7 +203,7 @@ def test_smoke():
# With this filter, every unexpected warning will be turned into an error # With this filter, every unexpected warning will be turned into an error
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error")
class TestHandleLegacyInterface: class TestHandleLegacyInterface:
class TestWeights(WeightsEnum): class ModelWeights(WeightsEnum):
Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict()) Sentinel = Weights(url="https://pytorch.org", transforms=lambda x: x, meta=dict())
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -211,11 +211,11 @@ class TestHandleLegacyInterface: ...@@ -211,11 +211,11 @@ class TestHandleLegacyInterface:
[ [
pytest.param(dict(), id="empty"), pytest.param(dict(), id="empty"),
pytest.param(dict(weights=None), id="None"), pytest.param(dict(weights=None), id="None"),
pytest.param(dict(weights=TestWeights.Sentinel), id="Weights"), pytest.param(dict(weights=ModelWeights.Sentinel), id="Weights"),
], ],
) )
def test_no_warn(self, kwargs): def test_no_warn(self, kwargs):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None): def builder(*, weights=None):
pass pass
...@@ -223,7 +223,7 @@ class TestHandleLegacyInterface: ...@@ -223,7 +223,7 @@ class TestHandleLegacyInterface:
@pytest.mark.parametrize("pretrained", (True, False)) @pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_pos(self, pretrained): def test_pretrained_pos(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None): def builder(*, weights=None):
pass pass
...@@ -232,7 +232,7 @@ class TestHandleLegacyInterface: ...@@ -232,7 +232,7 @@ class TestHandleLegacyInterface:
@pytest.mark.parametrize("pretrained", (True, False)) @pytest.mark.parametrize("pretrained", (True, False))
def test_pretrained_kw(self, pretrained): def test_pretrained_kw(self, pretrained):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None): def builder(*, weights=None):
pass pass
...@@ -242,12 +242,12 @@ class TestHandleLegacyInterface: ...@@ -242,12 +242,12 @@ class TestHandleLegacyInterface:
@pytest.mark.parametrize("pretrained", (True, False)) @pytest.mark.parametrize("pretrained", (True, False))
@pytest.mark.parametrize("positional", (True, False)) @pytest.mark.parametrize("positional", (True, False))
def test_equivalent_behavior_weights(self, pretrained, positional): def test_equivalent_behavior_weights(self, pretrained, positional):
@handle_legacy_interface(weights=("pretrained", self.TestWeights.Sentinel)) @handle_legacy_interface(weights=("pretrained", self.ModelWeights.Sentinel))
def builder(*, weights=None): def builder(*, weights=None):
pass pass
args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained)) args, kwargs = ((pretrained,), dict()) if positional else ((), dict(pretrained=pretrained))
with pytest.warns(UserWarning, match=f"weights={self.TestWeights.Sentinel if pretrained else None}"): with pytest.warns(UserWarning, match=f"weights={self.ModelWeights.Sentinel if pretrained else None}"):
builder(*args, **kwargs) builder(*args, **kwargs)
def test_multi_params(self): def test_multi_params(self):
...@@ -256,7 +256,7 @@ class TestHandleLegacyInterface: ...@@ -256,7 +256,7 @@ class TestHandleLegacyInterface:
@handle_legacy_interface( @handle_legacy_interface(
**{ **{
weights_param: (pretrained_param, self.TestWeights.Sentinel) weights_param: (pretrained_param, self.ModelWeights.Sentinel)
for weights_param, pretrained_param in zip(weights_params, pretrained_params) for weights_param, pretrained_param in zip(weights_params, pretrained_params)
} }
) )
...@@ -271,7 +271,7 @@ class TestHandleLegacyInterface: ...@@ -271,7 +271,7 @@ class TestHandleLegacyInterface:
@handle_legacy_interface( @handle_legacy_interface(
weights=( weights=(
"pretrained", "pretrained",
lambda kwargs: self.TestWeights.Sentinel if kwargs["flag"] else None, lambda kwargs: self.ModelWeights.Sentinel if kwargs["flag"] else None,
) )
) )
def builder(*, weights=None, flag): def builder(*, weights=None, flag):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment