Unverified Commit 93c85bbc authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

Consolidate repr (#5392)



* Consolidating __repr__ strings
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent c39c23ed
...@@ -72,13 +72,15 @@ class RandomMixup(torch.nn.Module): ...@@ -72,13 +72,15 @@ class RandomMixup(torch.nn.Module):
return batch, target return batch, target
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "num_classes={num_classes}" f"{self.__class__.__name__}("
s += ", p={p}" f"num_classes={self.num_classes}"
s += ", alpha={alpha}" f", p={self.p}"
s += ", inplace={inplace}" f", alpha={self.alpha}"
s += ")" f", inplace={self.inplace}"
return s.format(**self.__dict__) f")"
)
return s
class RandomCutmix(torch.nn.Module): class RandomCutmix(torch.nn.Module):
...@@ -162,10 +164,12 @@ class RandomCutmix(torch.nn.Module): ...@@ -162,10 +164,12 @@ class RandomCutmix(torch.nn.Module):
return batch, target return batch, target
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "num_classes={num_classes}" f"{self.__class__.__name__}("
s += ", p={p}" f"num_classes={self.num_classes}"
s += ", alpha={alpha}" f", p={self.p}"
s += ", inplace={inplace}" f", alpha={self.alpha}"
s += ")" f", inplace={self.inplace}"
return s.format(**self.__dict__) f")"
)
return s
...@@ -180,7 +180,7 @@ class DownloadConfig: ...@@ -180,7 +180,7 @@ class DownloadConfig:
self.md5 = md5 self.md5 = md5
self.id = id or url self.id = id or url
def __repr__(self): def __repr__(self) -> str:
return self.id return self.id
......
...@@ -239,13 +239,15 @@ class DefaultBoxGenerator(nn.Module): ...@@ -239,13 +239,15 @@ class DefaultBoxGenerator(nn.Module):
return torch.cat(default_boxes, dim=0) return torch.cat(default_boxes, dim=0)
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "aspect_ratios={aspect_ratios}" f"{self.__class__.__name__}("
s += ", clip={clip}" f"aspect_ratios={self.aspect_ratios}"
s += ", scales={scales}" f", clip={self.clip}"
s += ", steps={steps}" f", scales={self.scales}"
s += ")" f", steps={self.steps}"
return s.format(**self.__dict__) ")"
)
return s
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
......
...@@ -260,7 +260,7 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -260,7 +260,7 @@ class GeneralizedRCNNTransform(nn.Module):
return result return result
def __repr__(self) -> str: def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(" format_string = f"{self.__class__.__name__}("
_indent = "\n " _indent = "\n "
format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})" format_string += f"{_indent}Normalize(mean={self.image_mean}, std={self.image_std})"
format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')" format_string += f"{_indent}Resize(min_size={self.min_size}, max_size={self.max_size}, mode='bilinear')"
......
...@@ -61,15 +61,17 @@ class MBConvConfig: ...@@ -61,15 +61,17 @@ class MBConvConfig:
self.num_layers = self.adjust_depth(num_layers, depth_mult) self.num_layers = self.adjust_depth(num_layers, depth_mult)
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "expand_ratio={expand_ratio}" f"{self.__class__.__name__}("
s += ", kernel={kernel}" f"expand_ratio={self.expand_ratio}"
s += ", stride={stride}" f", kernel={self.kernel}"
s += ", input_channels={input_channels}" f", stride={self.stride}"
s += ", out_channels={out_channels}" f", input_channels={self.input_channels}"
s += ", num_layers={num_layers}" f", out_channels={self.out_channels}"
s += ")" f", num_layers={self.num_layers}"
return s.format(**self.__dict__) f")"
)
return s
@staticmethod @staticmethod
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
......
...@@ -179,14 +179,17 @@ class DeformConv2d(nn.Module): ...@@ -179,14 +179,17 @@ class DeformConv2d(nn.Module):
) )
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "{in_channels}" f"{self.__class__.__name__}("
s += ", {out_channels}" f"{self.in_channels}"
s += ", kernel_size={kernel_size}" f", {self.out_channels}"
s += ", stride={stride}" f", kernel_size={self.kernel_size}"
s += ", padding={padding}" if self.padding != (0, 0) else "" f", stride={self.stride}"
s += ", dilation={dilation}" if self.dilation != (1, 1) else "" )
s += ", groups={groups}" if self.groups != 1 else "" s += f", padding={self.padding}" if self.padding != (0, 0) else ""
s += f", dilation={self.dilation}" if self.dilation != (1, 1) else ""
s += f", groups={self.groups}" if self.groups != 1 else ""
s += ", bias=False" if self.bias is None else "" s += ", bias=False" if self.bias is None else ""
s += ")" s += ")"
return s.format(**self.__dict__)
return s
...@@ -78,9 +78,11 @@ class PSRoIAlign(nn.Module): ...@@ -78,9 +78,11 @@ class PSRoIAlign(nn.Module):
return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio) return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)
def __repr__(self) -> str: def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + "(" s = (
tmpstr += "output_size=" + str(self.output_size) f"{self.__class__.__name__}("
tmpstr += ", spatial_scale=" + str(self.spatial_scale) f"output_size={self.output_size}"
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) f", spatial_scale={self.spatial_scale}"
tmpstr += ")" f", sampling_ratio={self.sampling_ratio}"
return tmpstr f")"
)
return s
...@@ -64,8 +64,5 @@ class PSRoIPool(nn.Module): ...@@ -64,8 +64,5 @@ class PSRoIPool(nn.Module):
return ps_roi_pool(input, rois, self.output_size, self.spatial_scale) return ps_roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self) -> str: def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + "(" s = f"{self.__class__.__name__}(output_size={self.output_size}, spatial_scale={self.spatial_scale})"
tmpstr += "output_size=" + str(self.output_size) return s
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ")"
return tmpstr
...@@ -86,10 +86,12 @@ class RoIAlign(nn.Module): ...@@ -86,10 +86,12 @@ class RoIAlign(nn.Module):
return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned)
def __repr__(self) -> str: def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + "(" s = (
tmpstr += "output_size=" + str(self.output_size) f"{self.__class__.__name__}("
tmpstr += ", spatial_scale=" + str(self.spatial_scale) f"output_size={self.output_size}"
tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) f", spatial_scale={self.spatial_scale}"
tmpstr += ", aligned=" + str(self.aligned) f", sampling_ratio={self.sampling_ratio}"
tmpstr += ")" f", aligned={self.aligned}"
return tmpstr f")"
)
return s
...@@ -66,8 +66,5 @@ class RoIPool(nn.Module): ...@@ -66,8 +66,5 @@ class RoIPool(nn.Module):
return roi_pool(input, rois, self.output_size, self.spatial_scale) return roi_pool(input, rois, self.output_size, self.spatial_scale)
def __repr__(self) -> str: def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + "(" s = f"{self.__class__.__name__}(output_size={self.output_size}, spatial_scale={self.spatial_scale})"
tmpstr += "output_size=" + str(self.output_size) return s
tmpstr += ", spatial_scale=" + str(self.spatial_scale)
tmpstr += ")"
return tmpstr
...@@ -62,8 +62,5 @@ class StochasticDepth(nn.Module): ...@@ -62,8 +62,5 @@ class StochasticDepth(nn.Module):
return stochastic_depth(input, self.p, self.mode, self.training) return stochastic_depth(input, self.p, self.mode, self.training)
def __repr__(self) -> str: def __repr__(self) -> str:
tmpstr = self.__class__.__name__ + "(" s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
tmpstr += "p=" + str(self.p) return s
tmpstr += ", mode=" + str(self.mode)
tmpstr += ")"
return tmpstr
...@@ -96,5 +96,5 @@ class Feature(torch.Tensor): ...@@ -96,5 +96,5 @@ class Feature(torch.Tensor):
return cls(output, like=args[0]) return cls(output, like=args[0])
def __repr__(self): def __repr__(self) -> str:
return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__)
...@@ -67,7 +67,7 @@ class WeightsEnum(Enum): ...@@ -67,7 +67,7 @@ class WeightsEnum(Enum):
def get_state_dict(self, progress: bool) -> OrderedDict: def get_state_dict(self, progress: bool) -> OrderedDict:
return load_state_dict_from_url(self.url, progress=progress) return load_state_dict_from_url(self.url, progress=progress)
def __repr__(self): def __repr__(self) -> str:
return f"{self.__class__.__name__}.{self._name_}" return f"{self.__class__.__name__}.{self._name_}"
def __getattr__(self, name): def __getattr__(self, name):
......
...@@ -46,8 +46,8 @@ class RandomCropVideo(RandomCrop): ...@@ -46,8 +46,8 @@ class RandomCropVideo(RandomCrop):
i, j, h, w = self.get_params(clip, self.size) i, j, h, w = self.get_params(clip, self.size)
return F.crop(clip, i, j, h, w) return F.crop(clip, i, j, h, w)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(size={self.size})" return f"{self.__class__.__name__}(size={self.size})"
class RandomResizedCropVideo(RandomResizedCrop): class RandomResizedCropVideo(RandomResizedCrop):
...@@ -79,11 +79,8 @@ class RandomResizedCropVideo(RandomResizedCrop): ...@@ -79,11 +79,8 @@ class RandomResizedCropVideo(RandomResizedCrop):
i, j, h, w = self.get_params(clip, self.scale, self.ratio) i, j, h, w = self.get_params(clip, self.scale, self.ratio)
return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
def __repr__(self): def __repr__(self) -> str:
return ( return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
self.__class__.__name__
+ f"(size={self.size}, interpolation_mode={self.interpolation_mode}, scale={self.scale}, ratio={self.ratio})"
)
class CenterCropVideo: class CenterCropVideo:
...@@ -103,8 +100,8 @@ class CenterCropVideo: ...@@ -103,8 +100,8 @@ class CenterCropVideo:
""" """
return F.center_crop(clip, self.crop_size) return F.center_crop(clip, self.crop_size)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(crop_size={self.crop_size})" return f"{self.__class__.__name__}(crop_size={self.crop_size})"
class NormalizeVideo: class NormalizeVideo:
...@@ -128,8 +125,8 @@ class NormalizeVideo: ...@@ -128,8 +125,8 @@ class NormalizeVideo:
""" """
return F.normalize(clip, self.mean, self.std, self.inplace) return F.normalize(clip, self.mean, self.std, self.inplace)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(mean={self.mean}, std={self.std}, inplace={self.inplace})" return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
class ToTensorVideo: class ToTensorVideo:
...@@ -150,7 +147,7 @@ class ToTensorVideo: ...@@ -150,7 +147,7 @@ class ToTensorVideo:
""" """
return F.to_tensor(clip) return F.to_tensor(clip)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ return self.__class__.__name__
...@@ -175,5 +172,5 @@ class RandomHorizontalFlipVideo: ...@@ -175,5 +172,5 @@ class RandomHorizontalFlipVideo:
clip = F.hflip(clip) clip = F.hflip(clip)
return clip return clip
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
...@@ -280,7 +280,7 @@ class AutoAugment(torch.nn.Module): ...@@ -280,7 +280,7 @@ class AutoAugment(torch.nn.Module):
return img return img
def __repr__(self) -> str: def __repr__(self) -> str:
return self.__class__.__name__ + f"(policy={self.policy}, fill={self.fill})" return f"{self.__class__.__name__}(policy={self.policy}, fill={self.fill})"
class RandAugment(torch.nn.Module): class RandAugment(torch.nn.Module):
...@@ -363,14 +363,16 @@ class RandAugment(torch.nn.Module): ...@@ -363,14 +363,16 @@ class RandAugment(torch.nn.Module):
return img return img
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "num_ops={num_ops}" f"{self.__class__.__name__}("
s += ", magnitude={magnitude}" f"num_ops={self.num_ops}"
s += ", num_magnitude_bins={num_magnitude_bins}" f", magnitude={self.magnitude}"
s += ", interpolation={interpolation}" f", num_magnitude_bins={self.num_magnitude_bins}"
s += ", fill={fill}" f", interpolation={self.interpolation}"
s += ")" f", fill={self.fill}"
return s.format(**self.__dict__) f")"
)
return s
class TrivialAugmentWide(torch.nn.Module): class TrivialAugmentWide(torch.nn.Module):
...@@ -448,9 +450,11 @@ class TrivialAugmentWide(torch.nn.Module): ...@@ -448,9 +450,11 @@ class TrivialAugmentWide(torch.nn.Module):
return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + "(" s = (
s += "num_magnitude_bins={num_magnitude_bins}" f"{self.__class__.__name__}("
s += ", interpolation={interpolation}" f"num_magnitude_bins={self.num_magnitude_bins}"
s += ", fill={fill}" f", interpolation={self.interpolation}"
s += ")" f", fill={self.fill}"
return s.format(**self.__dict__) f")"
)
return s
...@@ -95,7 +95,7 @@ class Compose: ...@@ -95,7 +95,7 @@ class Compose:
img = t(img) img = t(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(" format_string = self.__class__.__name__ + "("
for t in self.transforms: for t in self.transforms:
format_string += "\n" format_string += "\n"
...@@ -134,8 +134,8 @@ class ToTensor: ...@@ -134,8 +134,8 @@ class ToTensor:
""" """
return F.to_tensor(pic) return F.to_tensor(pic)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + "()" return f"{self.__class__.__name__}()"
class PILToTensor: class PILToTensor:
...@@ -161,8 +161,8 @@ class PILToTensor: ...@@ -161,8 +161,8 @@ class PILToTensor:
""" """
return F.pil_to_tensor(pic) return F.pil_to_tensor(pic)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + "()" return f"{self.__class__.__name__}()"
class ConvertImageDtype(torch.nn.Module): class ConvertImageDtype(torch.nn.Module):
...@@ -226,7 +226,7 @@ class ToPILImage: ...@@ -226,7 +226,7 @@ class ToPILImage:
""" """
return F.to_pil_image(pic, self.mode) return F.to_pil_image(pic, self.mode)
def __repr__(self): def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(" format_string = self.__class__.__name__ + "("
if self.mode is not None: if self.mode is not None:
format_string += f"mode={self.mode}" format_string += f"mode={self.mode}"
...@@ -269,8 +269,8 @@ class Normalize(torch.nn.Module): ...@@ -269,8 +269,8 @@ class Normalize(torch.nn.Module):
""" """
return F.normalize(tensor, self.mean, self.std, self.inplace) return F.normalize(tensor, self.mean, self.std, self.inplace)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(mean={self.mean}, std={self.std})" return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
class Resize(torch.nn.Module): class Resize(torch.nn.Module):
...@@ -348,9 +348,9 @@ class Resize(torch.nn.Module): ...@@ -348,9 +348,9 @@ class Resize(torch.nn.Module):
""" """
return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias) return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
def __repr__(self): def __repr__(self) -> str:
detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})" detail = f"(size={self.size}, interpolation={self.interpolation.value}, max_size={self.max_size}, antialias={self.antialias})"
return self.__class__.__name__ + detail return f"{self.__class__.__name__}{detail}"
class CenterCrop(torch.nn.Module): class CenterCrop(torch.nn.Module):
...@@ -380,8 +380,8 @@ class CenterCrop(torch.nn.Module): ...@@ -380,8 +380,8 @@ class CenterCrop(torch.nn.Module):
""" """
return F.center_crop(img, self.size) return F.center_crop(img, self.size)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(size={self.size})" return f"{self.__class__.__name__}(size={self.size})"
class Pad(torch.nn.Module): class Pad(torch.nn.Module):
...@@ -453,8 +453,8 @@ class Pad(torch.nn.Module): ...@@ -453,8 +453,8 @@ class Pad(torch.nn.Module):
""" """
return F.pad(img, self.padding, self.fill, self.padding_mode) return F.pad(img, self.padding, self.fill, self.padding_mode)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})" return f"{self.__class__.__name__}(padding={self.padding}, fill={self.fill}, padding_mode={self.padding_mode})"
class Lambda: class Lambda:
...@@ -473,8 +473,8 @@ class Lambda: ...@@ -473,8 +473,8 @@ class Lambda:
def __call__(self, img): def __call__(self, img):
return self.lambd(img) return self.lambd(img)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + "()" return f"{self.__class__.__name__}()"
class RandomTransforms: class RandomTransforms:
...@@ -493,7 +493,7 @@ class RandomTransforms: ...@@ -493,7 +493,7 @@ class RandomTransforms:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def __repr__(self): def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(" format_string = self.__class__.__name__ + "("
for t in self.transforms: for t in self.transforms:
format_string += "\n" format_string += "\n"
...@@ -535,7 +535,7 @@ class RandomApply(torch.nn.Module): ...@@ -535,7 +535,7 @@ class RandomApply(torch.nn.Module):
img = t(img) img = t(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(" format_string = self.__class__.__name__ + "("
format_string += f"\n p={self.p}" format_string += f"\n p={self.p}"
for t in self.transforms: for t in self.transforms:
...@@ -569,10 +569,8 @@ class RandomChoice(RandomTransforms): ...@@ -569,10 +569,8 @@ class RandomChoice(RandomTransforms):
t = random.choices(self.transforms, weights=self.p)[0] t = random.choices(self.transforms, weights=self.p)[0]
return t(*args) return t(*args)
def __repr__(self): def __repr__(self) -> str:
format_string = super().__repr__() return f"{super().__repr__()}(p={self.p})"
format_string += f"(p={self.p})"
return format_string
class RandomCrop(torch.nn.Module): class RandomCrop(torch.nn.Module):
...@@ -679,8 +677,8 @@ class RandomCrop(torch.nn.Module): ...@@ -679,8 +677,8 @@ class RandomCrop(torch.nn.Module):
return F.crop(img, i, j, h, w) return F.crop(img, i, j, h, w)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(size={self.size}, padding={self.padding})" return f"{self.__class__.__name__}(size={self.size}, padding={self.padding})"
class RandomHorizontalFlip(torch.nn.Module): class RandomHorizontalFlip(torch.nn.Module):
...@@ -710,8 +708,8 @@ class RandomHorizontalFlip(torch.nn.Module): ...@@ -710,8 +708,8 @@ class RandomHorizontalFlip(torch.nn.Module):
return F.hflip(img) return F.hflip(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class RandomVerticalFlip(torch.nn.Module): class RandomVerticalFlip(torch.nn.Module):
...@@ -741,8 +739,8 @@ class RandomVerticalFlip(torch.nn.Module): ...@@ -741,8 +739,8 @@ class RandomVerticalFlip(torch.nn.Module):
return F.vflip(img) return F.vflip(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class RandomPerspective(torch.nn.Module): class RandomPerspective(torch.nn.Module):
...@@ -842,8 +840,8 @@ class RandomPerspective(torch.nn.Module): ...@@ -842,8 +840,8 @@ class RandomPerspective(torch.nn.Module):
endpoints = [topleft, topright, botright, botleft] endpoints = [topleft, topright, botright, botleft]
return startpoints, endpoints return startpoints, endpoints
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class RandomResizedCrop(torch.nn.Module): class RandomResizedCrop(torch.nn.Module):
...@@ -954,7 +952,7 @@ class RandomResizedCrop(torch.nn.Module): ...@@ -954,7 +952,7 @@ class RandomResizedCrop(torch.nn.Module):
i, j, h, w = self.get_params(img, self.scale, self.ratio) i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) return F.resized_crop(img, i, j, h, w, self.size, self.interpolation)
def __repr__(self): def __repr__(self) -> str:
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + f"(size={self.size}" format_string = self.__class__.__name__ + f"(size={self.size}"
format_string += f", scale={tuple(round(s, 4) for s in self.scale)}" format_string += f", scale={tuple(round(s, 4) for s in self.scale)}"
...@@ -1006,8 +1004,8 @@ class FiveCrop(torch.nn.Module): ...@@ -1006,8 +1004,8 @@ class FiveCrop(torch.nn.Module):
""" """
return F.five_crop(img, self.size) return F.five_crop(img, self.size)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(size={self.size})" return f"{self.__class__.__name__}(size={self.size})"
class TenCrop(torch.nn.Module): class TenCrop(torch.nn.Module):
...@@ -1056,8 +1054,8 @@ class TenCrop(torch.nn.Module): ...@@ -1056,8 +1054,8 @@ class TenCrop(torch.nn.Module):
""" """
return F.ten_crop(img, self.size, self.vertical_flip) return F.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(size={self.size}, vertical_flip={self.vertical_flip})" return f"{self.__class__.__name__}(size={self.size}, vertical_flip={self.vertical_flip})"
class LinearTransformation(torch.nn.Module): class LinearTransformation(torch.nn.Module):
...@@ -1130,11 +1128,13 @@ class LinearTransformation(torch.nn.Module): ...@@ -1130,11 +1128,13 @@ class LinearTransformation(torch.nn.Module):
tensor = transformed_tensor.view(shape) tensor = transformed_tensor.view(shape)
return tensor return tensor
def __repr__(self): def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(transformation_matrix=" s = (
format_string += str(self.transformation_matrix.tolist()) + ")" f"{self.__class__.__name__}(transformation_matrix="
format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")" f"{self.transformation_matrix.tolist()}"
return format_string f", mean_vector={self.mean_vector.tolist()})"
)
return s
class ColorJitter(torch.nn.Module): class ColorJitter(torch.nn.Module):
...@@ -1242,13 +1242,15 @@ class ColorJitter(torch.nn.Module): ...@@ -1242,13 +1242,15 @@ class ColorJitter(torch.nn.Module):
return img return img
def __repr__(self): def __repr__(self) -> str:
format_string = self.__class__.__name__ + "(" s = (
format_string += f"brightness={self.brightness}" f"{self.__class__.__name__}("
format_string += f", contrast={self.contrast}" f"brightness={self.brightness}"
format_string += f", saturation={self.saturation}" f", contrast={self.contrast}"
format_string += f", hue={self.hue})" f", saturation={self.saturation}"
return format_string f", hue={self.hue})"
)
return s
class RandomRotation(torch.nn.Module): class RandomRotation(torch.nn.Module):
...@@ -1346,7 +1348,7 @@ class RandomRotation(torch.nn.Module): ...@@ -1346,7 +1348,7 @@ class RandomRotation(torch.nn.Module):
return F.rotate(img, angle, self.resample, self.expand, self.center, fill) return F.rotate(img, angle, self.resample, self.expand, self.center, fill)
def __repr__(self): def __repr__(self) -> str:
interpolate_str = self.interpolation.value interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + f"(degrees={self.degrees}" format_string = self.__class__.__name__ + f"(degrees={self.degrees}"
format_string += f", interpolation={interpolate_str}" format_string += f", interpolation={interpolate_str}"
...@@ -1529,24 +1531,17 @@ class RandomAffine(torch.nn.Module): ...@@ -1529,24 +1531,17 @@ class RandomAffine(torch.nn.Module):
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center) return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
def __repr__(self): def __repr__(self) -> str:
s = "{name}(degrees={degrees}" s = f"{self.__class__.__name__}(degrees={self.degrees}"
if self.translate is not None: s += f", translate={self.translate}" if self.translate is not None else ""
s += ", translate={translate}" s += f", scale={self.scale}" if self.scale is not None else ""
if self.scale is not None: s += f", shear={self.shear}" if self.shear is not None else ""
s += ", scale={scale}" s += f", interpolation={self.interpolation.value}" if self.interpolation != InterpolationMode.NEAREST else ""
if self.shear is not None: s += f", fill={self.fill}" if self.fill != 0 else ""
s += ", shear={shear}" s += f", center={self.center}" if self.center is not None else ""
if self.interpolation != InterpolationMode.NEAREST:
s += ", interpolation={interpolation}"
if self.fill != 0:
s += ", fill={fill}"
if self.center is not None:
s += ", center={center}"
s += ")" s += ")"
d = dict(self.__dict__)
d["interpolation"] = self.interpolation.value return s
return s.format(name=self.__class__.__name__, **d)
class Grayscale(torch.nn.Module): class Grayscale(torch.nn.Module):
...@@ -1580,8 +1575,8 @@ class Grayscale(torch.nn.Module): ...@@ -1580,8 +1575,8 @@ class Grayscale(torch.nn.Module):
""" """
return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels)
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(num_output_channels={self.num_output_channels})" return f"{self.__class__.__name__}(num_output_channels={self.num_output_channels})"
class RandomGrayscale(torch.nn.Module): class RandomGrayscale(torch.nn.Module):
...@@ -1618,8 +1613,8 @@ class RandomGrayscale(torch.nn.Module): ...@@ -1618,8 +1613,8 @@ class RandomGrayscale(torch.nn.Module):
return F.rgb_to_grayscale(img, num_output_channels=num_output_channels) return F.rgb_to_grayscale(img, num_output_channels=num_output_channels)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class RandomErasing(torch.nn.Module): class RandomErasing(torch.nn.Module):
...@@ -1748,13 +1743,16 @@ class RandomErasing(torch.nn.Module): ...@@ -1748,13 +1743,16 @@ class RandomErasing(torch.nn.Module):
return F.erase(img, x, y, h, w, v, self.inplace) return F.erase(img, x, y, h, w, v, self.inplace)
return img return img
def __repr__(self): def __repr__(self) -> str:
s = f"(p={self.p}, " s = (
s += f"scale={self.scale}, " f"{self.__class__.__name__}"
s += f"ratio={self.ratio}, " f"(p={self.p}, "
s += f"value={self.value}, " f"scale={self.scale}, "
s += f"inplace={self.inplace})" f"ratio={self.ratio}, "
return self.__class__.__name__ + s f"value={self.value}, "
f"inplace={self.inplace})"
)
return s
class GaussianBlur(torch.nn.Module): class GaussianBlur(torch.nn.Module):
...@@ -1818,10 +1816,9 @@ class GaussianBlur(torch.nn.Module): ...@@ -1818,10 +1816,9 @@ class GaussianBlur(torch.nn.Module):
sigma = self.get_params(self.sigma[0], self.sigma[1]) sigma = self.get_params(self.sigma[0], self.sigma[1])
return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) return F.gaussian_blur(img, self.kernel_size, [sigma, sigma])
def __repr__(self): def __repr__(self) -> str:
s = f"(kernel_size={self.kernel_size}, " s = f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"
s += f"sigma={self.sigma})" return s
return self.__class__.__name__ + s
def _setup_size(size, error_msg): def _setup_size(size, error_msg):
...@@ -1883,8 +1880,8 @@ class RandomInvert(torch.nn.Module): ...@@ -1883,8 +1880,8 @@ class RandomInvert(torch.nn.Module):
return F.invert(img) return F.invert(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class RandomPosterize(torch.nn.Module): class RandomPosterize(torch.nn.Module):
...@@ -1916,8 +1913,8 @@ class RandomPosterize(torch.nn.Module): ...@@ -1916,8 +1913,8 @@ class RandomPosterize(torch.nn.Module):
return F.posterize(img, self.bits) return F.posterize(img, self.bits)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(bits={self.bits},p={self.p})" return f"{self.__class__.__name__}(bits={self.bits},p={self.p})"
class RandomSolarize(torch.nn.Module): class RandomSolarize(torch.nn.Module):
...@@ -1949,8 +1946,8 @@ class RandomSolarize(torch.nn.Module): ...@@ -1949,8 +1946,8 @@ class RandomSolarize(torch.nn.Module):
return F.solarize(img, self.threshold) return F.solarize(img, self.threshold)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(threshold={self.threshold},p={self.p})" return f"{self.__class__.__name__}(threshold={self.threshold},p={self.p})"
class RandomAdjustSharpness(torch.nn.Module): class RandomAdjustSharpness(torch.nn.Module):
...@@ -1982,8 +1979,8 @@ class RandomAdjustSharpness(torch.nn.Module): ...@@ -1982,8 +1979,8 @@ class RandomAdjustSharpness(torch.nn.Module):
return F.adjust_sharpness(img, self.sharpness_factor) return F.adjust_sharpness(img, self.sharpness_factor)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(sharpness_factor={self.sharpness_factor},p={self.p})" return f"{self.__class__.__name__}(sharpness_factor={self.sharpness_factor},p={self.p})"
class RandomAutocontrast(torch.nn.Module): class RandomAutocontrast(torch.nn.Module):
...@@ -2013,8 +2010,8 @@ class RandomAutocontrast(torch.nn.Module): ...@@ -2013,8 +2010,8 @@ class RandomAutocontrast(torch.nn.Module):
return F.autocontrast(img) return F.autocontrast(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
class RandomEqualize(torch.nn.Module): class RandomEqualize(torch.nn.Module):
...@@ -2044,5 +2041,5 @@ class RandomEqualize(torch.nn.Module): ...@@ -2044,5 +2041,5 @@ class RandomEqualize(torch.nn.Module):
return F.equalize(img) return F.equalize(img)
return img return img
def __repr__(self): def __repr__(self) -> str:
return self.__class__.__name__ + f"(p={self.p})" return f"{self.__class__.__name__}(p={self.p})"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment