Unverified Commit 93df9a50 authored by RoiEX's avatar RoiEX Committed by GitHub
Browse files

Add missing type hints to ColorJitter constructor (#7087)

parent 8985b598
...@@ -3,7 +3,7 @@ import numbers ...@@ -3,7 +3,7 @@ import numbers
import random import random
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -1172,7 +1172,13 @@ class ColorJitter(torch.nn.Module): ...@@ -1172,7 +1172,13 @@ class ColorJitter(torch.nn.Module):
or use an interpolation that generates negative values before using this function. or use an interpolation that generates negative values before using this function.
""" """
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): def __init__(
self,
brightness: Union[float, Tuple[float, float]] = 0,
contrast: Union[float, Tuple[float, float]] = 0,
saturation: Union[float, Tuple[float, float]] = 0,
hue: Union[float, Tuple[float, float]] = 0,
) -> None:
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
self.brightness = self._check_input(brightness, "brightness") self.brightness = self._check_input(brightness, "brightness")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment