"src/array/cuda/array_op_impl.hip" did not exist on "d3ae7544cda01fa4c4804ab94bf0671b4868bcb5"
Unverified Commit 60c78f28 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

make RandomErasing scriptable for integer value (#7134)

parent 5dd95944
...@@ -672,7 +672,17 @@ def test_autoaugment__op_apply_shear(interpolation, mode): ...@@ -672,7 +672,17 @@ def test_autoaugment__op_apply_shear(interpolation, mode):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config", "config",
[{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}], [
{},
{"value": 1},
{"value": 0.2},
{"value": "random"},
{"value": (1, 1, 1)},
{"value": (0.2, 0.2, 0.2)},
{"value": [1, 1, 1]},
{"value": [0.2, 0.2, 0.2]},
{"value": "random", "ratio": (0.1, 0.2)},
],
) )
def test_random_erasing(device, config): def test_random_erasing(device, config):
tensor, _ = _create_data(24, 32, channels=3, device=device) tensor, _ = _create_data(24, 32, channels=3, device=device)
......
...@@ -23,7 +23,7 @@ class RandomErasing(_RandomApplyTransform): ...@@ -23,7 +23,7 @@ class RandomErasing(_RandomApplyTransform):
p: float = 0.5, p: float = 0.5,
scale: Tuple[float, float] = (0.02, 0.33), scale: Tuple[float, float] = (0.02, 0.33),
ratio: Tuple[float, float] = (0.3, 3.3), ratio: Tuple[float, float] = (0.3, 3.3),
value: float = 0, value: float = 0.0,
inplace: bool = False, inplace: bool = False,
): ):
super().__init__(p=p) super().__init__(p=p)
...@@ -42,11 +42,11 @@ class RandomErasing(_RandomApplyTransform): ...@@ -42,11 +42,11 @@ class RandomErasing(_RandomApplyTransform):
self.scale = scale self.scale = scale
self.ratio = ratio self.ratio = ratio
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
self.value = [value] self.value = [float(value)]
elif isinstance(value, str): elif isinstance(value, str):
self.value = None self.value = None
elif isinstance(value, tuple): elif isinstance(value, (list, tuple)):
self.value = list(value) self.value = [float(v) for v in value]
else: else:
self.value = value self.value = value
self.inplace = inplace self.inplace = inplace
......
...@@ -1713,11 +1713,11 @@ class RandomErasing(torch.nn.Module): ...@@ -1713,11 +1713,11 @@ class RandomErasing(torch.nn.Module):
# cast self.value to script acceptable type # cast self.value to script acceptable type
if isinstance(self.value, (int, float)): if isinstance(self.value, (int, float)):
value = [self.value] value = [float(self.value)]
elif isinstance(self.value, str): elif isinstance(self.value, str):
value = None value = None
elif isinstance(self.value, tuple): elif isinstance(self.value, (list, tuple)):
value = list(self.value) value = [float(v) for v in self.value]
else: else:
value = self.value value = self.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment