Unverified Commit c4c0ef98 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make weights deepcopyable (#6883)

* make weights deepcopyable

* add test

* test enum member instead of whole enum
parent d95fbaf1
import copy
import os
import pytest
......@@ -59,6 +60,25 @@ def test_get_model_weights(name, weight):
assert models.get_model_weights(name) == weight
@pytest.mark.parametrize("copy_fn", [copy.copy, copy.deepcopy])
@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_copyable(copy_fn, name):
model_weights = models.get_model_weights(name)
for weights in list(model_weights):
copied_weights = copy_fn(weights)
assert copied_weights is weights
@pytest.mark.parametrize(
"module", [models, models.detection, models.quantization, models.segmentation, models.video, models.optical_flow]
)
......
......@@ -75,6 +75,9 @@ class WeightsEnum(StrEnum):
return object.__getattribute__(self.value, name)
return super().__getattr__(name)
def __deepcopy__(self, memodict=None):
return self
def get_weight(name: str) -> WeightsEnum:
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment