Unverified Commit ee26e9c2 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Add preprocessing information on the Weights documentation (#6009)

* Adding `__repr__` in presets

* Adds `describe()` methods to all presets.

* Adding transform descriptions in the documentation.

* Change "preprocessing" to "inference"
parent c67a5839
...@@ -366,6 +366,11 @@ def inject_weight_metadata(app, what, name, obj, options, lines): ...@@ -366,6 +366,11 @@ def inject_weight_metadata(app, what, name, obj, options, lines):
lines += [".. table::", ""] lines += [".. table::", ""]
lines += textwrap.indent(table, " " * 4).split("\n") lines += textwrap.indent(table, " " * 4).split("\n")
lines.append("") lines.append("")
lines.append(
f"The inference transforms are available at ``{str(field)}.transforms`` and "
f"perform the following operations: {field.transforms().describe()}"
)
lines.append("")
def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None): def generate_weights_table(module, table_name, metrics, include_patterns=None, exclude_patterns=None):
......
...@@ -25,6 +25,12 @@ class ObjectDetection(nn.Module): ...@@ -25,6 +25,12 @@ class ObjectDetection(nn.Module):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
return F.convert_image_dtype(img, torch.float) return F.convert_image_dtype(img, torch.float)
def __repr__(self) -> str:
return self.__class__.__name__ + "()"
def describe(self) -> str:
return "The images are rescaled to ``[0.0, 1.0]``."
class ImageClassification(nn.Module): class ImageClassification(nn.Module):
def __init__( def __init__(
...@@ -37,21 +43,38 @@ class ImageClassification(nn.Module): ...@@ -37,21 +43,38 @@ class ImageClassification(nn.Module):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
self._crop_size = [crop_size] self.crop_size = [crop_size]
self._size = [resize_size] self.resize_size = [resize_size]
self._mean = list(mean) self.mean = list(mean)
self._std = list(std) self.std = list(std)
self._interpolation = interpolation self.interpolation = interpolation
def forward(self, img: Tensor) -> Tensor: def forward(self, img: Tensor) -> Tensor:
img = F.resize(img, self._size, interpolation=self._interpolation) img = F.resize(img, self.resize_size, interpolation=self.interpolation)
img = F.center_crop(img, self._crop_size) img = F.center_crop(img, self.crop_size)
if not isinstance(img, Tensor): if not isinstance(img, Tensor):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float) img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self._mean, std=self._std) img = F.normalize(img, mean=self.mean, std=self.std)
return img return img
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n crop_size={self.crop_size}"
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
)
class VideoClassification(nn.Module): class VideoClassification(nn.Module):
def __init__( def __init__(
...@@ -64,11 +87,11 @@ class VideoClassification(nn.Module): ...@@ -64,11 +87,11 @@ class VideoClassification(nn.Module):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
self._crop_size = list(crop_size) self.crop_size = list(crop_size)
self._size = list(resize_size) self.resize_size = list(resize_size)
self._mean = list(mean) self.mean = list(mean)
self._std = list(std) self.std = list(std)
self._interpolation = interpolation self.interpolation = interpolation
def forward(self, vid: Tensor) -> Tensor: def forward(self, vid: Tensor) -> Tensor:
need_squeeze = False need_squeeze = False
...@@ -79,11 +102,11 @@ class VideoClassification(nn.Module): ...@@ -79,11 +102,11 @@ class VideoClassification(nn.Module):
vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W) vid = vid.permute(0, 1, 4, 2, 3) # (N, T, H, W, C) => (N, T, C, H, W)
N, T, C, H, W = vid.shape N, T, C, H, W = vid.shape
vid = vid.view(-1, C, H, W) vid = vid.view(-1, C, H, W)
vid = F.resize(vid, self._size, interpolation=self._interpolation) vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
vid = F.center_crop(vid, self._crop_size) vid = F.center_crop(vid, self.crop_size)
vid = F.convert_image_dtype(vid, torch.float) vid = F.convert_image_dtype(vid, torch.float)
vid = F.normalize(vid, mean=self._mean, std=self._std) vid = F.normalize(vid, mean=self.mean, std=self.std)
H, W = self._crop_size H, W = self.crop_size
vid = vid.view(N, T, C, H, W) vid = vid.view(N, T, C, H, W)
vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W) vid = vid.permute(0, 2, 1, 3, 4) # (N, T, C, H, W) => (N, C, T, H, W)
...@@ -91,6 +114,23 @@ class VideoClassification(nn.Module): ...@@ -91,6 +114,23 @@ class VideoClassification(nn.Module):
vid = vid.squeeze(dim=0) vid = vid.squeeze(dim=0)
return vid return vid
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n crop_size={self.crop_size}"
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
f"The video frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
f"followed by a central crop of ``crop_size={self.crop_size}``. Then the values are rescaled to "
f"``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
)
class SemanticSegmentation(nn.Module): class SemanticSegmentation(nn.Module):
def __init__( def __init__(
...@@ -102,20 +142,35 @@ class SemanticSegmentation(nn.Module): ...@@ -102,20 +142,35 @@ class SemanticSegmentation(nn.Module):
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
) -> None: ) -> None:
super().__init__() super().__init__()
self._size = [resize_size] if resize_size is not None else None self.resize_size = [resize_size] if resize_size is not None else None
self._mean = list(mean) self.mean = list(mean)
self._std = list(std) self.std = list(std)
self._interpolation = interpolation self.interpolation = interpolation
def forward(self, img: Tensor) -> Tensor: def forward(self, img: Tensor) -> Tensor:
if isinstance(self._size, list): if isinstance(self.resize_size, list):
img = F.resize(img, self._size, interpolation=self._interpolation) img = F.resize(img, self.resize_size, interpolation=self.interpolation)
if not isinstance(img, Tensor): if not isinstance(img, Tensor):
img = F.pil_to_tensor(img) img = F.pil_to_tensor(img)
img = F.convert_image_dtype(img, torch.float) img = F.convert_image_dtype(img, torch.float)
img = F.normalize(img, mean=self._mean, std=self._std) img = F.normalize(img, mean=self.mean, std=self.std)
return img return img
def __repr__(self) -> str:
format_string = self.__class__.__name__ + "("
format_string += f"\n resize_size={self.resize_size}"
format_string += f"\n mean={self.mean}"
format_string += f"\n std={self.std}"
format_string += f"\n interpolation={self.interpolation}"
format_string += "\n)"
return format_string
def describe(self) -> str:
return (
f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. "
f"Then the values are rescaled to ``[0.0, 1.0]`` and normalized using ``mean={self.mean}`` and ``std={self.std}``."
)
class OpticalFlow(nn.Module): class OpticalFlow(nn.Module):
def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]: def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
...@@ -135,3 +190,9 @@ class OpticalFlow(nn.Module): ...@@ -135,3 +190,9 @@ class OpticalFlow(nn.Module):
img2 = img2.contiguous() img2 = img2.contiguous()
return img1, img2 return img1, img2
def __repr__(self) -> str:
return self.__class__.__name__ + "()"
def describe(self) -> str:
return "The images are rescaled to ``[-1.0, 1.0]``."
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment