Unverified Commit 60c78f28 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make RandomErasing scriptable for integer value (#7134)

parent 5dd95944
......@@ -672,7 +672,17 @@ def test_autoaugment__op_apply_shear(interpolation, mode):
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"config",
[{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}],
[
{},
{"value": 1},
{"value": 0.2},
{"value": "random"},
{"value": (1, 1, 1)},
{"value": (0.2, 0.2, 0.2)},
{"value": [1, 1, 1]},
{"value": [0.2, 0.2, 0.2]},
{"value": "random", "ratio": (0.1, 0.2)},
],
)
def test_random_erasing(device, config):
tensor, _ = _create_data(24, 32, channels=3, device=device)
......
......@@ -23,7 +23,7 @@ class RandomErasing(_RandomApplyTransform):
p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0,
value: float = 0.0,
inplace: bool = False,
):
super().__init__(p=p)
......@@ -42,11 +42,11 @@ class RandomErasing(_RandomApplyTransform):
self.scale = scale
self.ratio = ratio
if isinstance(value, (int, float)):
self.value = [value]
self.value = [float(value)]
elif isinstance(value, str):
self.value = None
elif isinstance(value, tuple):
self.value = list(value)
elif isinstance(value, (list, tuple)):
self.value = [float(v) for v in value]
else:
self.value = value
self.inplace = inplace
......
......@@ -1713,11 +1713,11 @@ class RandomErasing(torch.nn.Module):
# cast self.value to script acceptable type
if isinstance(self.value, (int, float)):
value = [self.value]
value = [float(self.value)]
elif isinstance(self.value, str):
value = None
elif isinstance(self.value, tuple):
value = list(self.value)
elif isinstance(self.value, (list, tuple)):
value = [float(v) for v in self.value]
else:
value = self.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment