Unverified Commit 1dc0318f authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add docs for containers and undeprecate p for RandomChoice (#7311)


Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
parent 14c003bd
......@@ -1359,11 +1359,8 @@ class TestContainers:
class TestRandomChoice:
def test_assertions(self):
with pytest.warns(UserWarning, match="Argument p is deprecated and will be removed"):
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1, 2])
with pytest.raises(ValueError, match="The number of probabilities doesn't match the number of transforms"):
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], probabilities=[1])
transforms.RandomChoice([transforms.Pad(2), transforms.RandomCrop(28)], p=[1])
class TestRandomIoUCrop:
......
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import torch
......@@ -78,7 +77,7 @@ class RandomApply(Transform):
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (float): probability
p (float): probability of applying the list of transforms
"""
_v1_transform_cls = _transforms.RandomApply
......@@ -119,39 +118,38 @@ class RandomChoice(Transform):
.. betastatus:: RandomChoice transform
This transform does not support torchscript."""
This transform does not support torchscript.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (list of floats or None, optional): probability of each transform being picked.
If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
(default), all transforms have the same probability.
"""
def __init__(
self,
transforms: Sequence[Callable],
p: Optional[List[float]] = None,
probabilities: Optional[List[float]] = None,
) -> None:
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is not None:
warnings.warn(
"Argument p is deprecated and will be removed in a future release. "
"Please use probabilities argument instead."
)
probabilities = p
if probabilities is None:
probabilities = [1] * len(transforms)
elif len(probabilities) != len(transforms):
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"The number of probabilities doesn't match the number of transforms: "
f"{len(probabilities)} != {len(transforms)}"
f"The number of p doesn't match the number of transforms: " f"{len(p)} != {len(transforms)}"
)
super().__init__()
self.transforms = transforms
total = sum(probabilities)
self.probabilities = [prob / total for prob in probabilities]
total = sum(p)
self.p = [prob / total for prob in p]
def forward(self, *inputs: Any) -> Any:
idx = int(torch.multinomial(torch.tensor(self.probabilities), 1))
idx = int(torch.multinomial(torch.tensor(self.p), 1))
transform = self.transforms[idx]
return transform(*inputs)
......@@ -162,6 +160,9 @@ class RandomOrder(Transform):
.. betastatus:: RandomOrder transform
This transform does not support torchscript.
Args:
transforms (sequence or torch.nn.Module): list of transformations
"""
def __init__(self, transforms: Sequence[Callable]) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment