Unverified Commit ad3c3f78 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding docs for RandAugment (#4349)

* Adding docs for RandAugment.

* Fix docs.
parent 5a815541
...@@ -214,8 +214,8 @@ Generic Transforms ...@@ -214,8 +214,8 @@ Generic Transforms
:members: :members:
AutoAugment Transforms Automatic Augmentation Transforms
---------------------- ---------------------------------
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models. `AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
...@@ -229,6 +229,10 @@ The new transform can be used standalone or mixed-and-matched with existing tran ...@@ -229,6 +229,10 @@ The new transform can be used standalone or mixed-and-matched with existing tran
.. autoclass:: AutoAugment .. autoclass:: AutoAugment
:members: :members:
`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models.
.. autoclass:: RandAugment
:members:
.. _functional_transforms: .. _functional_transforms:
......
...@@ -245,6 +245,14 @@ imgs = [ ...@@ -245,6 +245,14 @@ imgs = [
row_title = [str(policy).split('.')[-1] for policy in policies] row_title = [str(policy).split('.')[-1] for policy in policies]
plot(imgs, row_title=row_title) plot(imgs, row_title=row_title)
####################################
# RandAugment
# ~~~~~~~~~~~
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
augmenter = T.RandAugment()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)
#################################### ####################################
# Randomly-applied transforms # Randomly-applied transforms
# --------------------------- # ---------------------------
......
...@@ -245,7 +245,7 @@ class AutoAugment(torch.nn.Module): ...@@ -245,7 +245,7 @@ class AutoAugment(torch.nn.Module):
class RandAugment(torch.nn.Module): class RandAugment(torch.nn.Module):
r"""RandAugment data augmentation method based on r"""RandAugment data augmentation method based on
`"RandAugment: Practical automated data augmentation with a reduced search space" `"RandAugment: Practical automated data augmentation with a reduced search space"
<https://arxiv.org/abs/1909.13719>`. <https://arxiv.org/abs/1909.13719>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB". If img is PIL Image, it is expected to be in mode "L" or "RGB".
...@@ -293,6 +293,7 @@ class RandAugment(torch.nn.Module): ...@@ -293,6 +293,7 @@ class RandAugment(torch.nn.Module):
def forward(self, img: Tensor) -> Tensor: def forward(self, img: Tensor) -> Tensor:
""" """
img (PIL Image or Tensor): Image to be transformed. img (PIL Image or Tensor): Image to be transformed.
Returns: Returns:
PIL Image or Tensor: Transformed image. PIL Image or Tensor: Transformed image.
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment