"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "da113364dff2f52cc26690d617d766272ed643ca"
Unverified Commit b83d5f7c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add support for apply probability to CutMix and MixUp (#6448)

parent 2a0eea82
...@@ -6,7 +6,7 @@ from typing import Any, Dict, Tuple ...@@ -6,7 +6,7 @@ from typing import Any, Dict, Tuple
import PIL.Image import PIL.Image
import torch import torch
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform from torchvision.prototype.transforms import functional as F
from ._transform import _RandomApplyTransform from ._transform import _RandomApplyTransform
from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
...@@ -97,9 +97,9 @@ class RandomErasing(_RandomApplyTransform): ...@@ -97,9 +97,9 @@ class RandomErasing(_RandomApplyTransform):
return inpt return inpt
class _BaseMixupCutmix(Transform): class _BaseMixupCutmix(_RandomApplyTransform):
def __init__(self, *, alpha: float) -> None: def __init__(self, *, alpha: float, p: float = 0.5) -> None:
super().__init__() super().__init__(p=p)
self.alpha = alpha self.alpha = alpha
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment