Unverified commit c4deb7b3, authored by Francesco Saverio Zuppichini, committed by GitHub

Feature Extractor accepts `segmentation_maps` (#15964)



* feature extractor accepts

* resolved conversations

* added examples in test for ADE20K

* num_classes -> num_labels

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* resolving conversations

* resolving conversations

* removed ADE

* CI

* minor changes in conversion script

* reduce_labels in feature extractor

* minor changes

* correct preprocess for instance segmentation maps

* minor changes

* minor changes

* CI

* debugging

* better padding

* going to update labels inside the model

* going to update labels inside the model

* minor changes

* tests

* removed changes in feature_extractor_utils

* conversation

* conversation

* example in feature extractor

* more docstring in modeling

* test

* make style

* doc
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c2f8eaf6
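
In short, after this change the feature extractor consumes raw segmentation maps directly and returns per-image `mask_labels`/`class_labels` lists alongside `pixel_values`. A minimal sketch of the new call pattern (mirroring the updated integration test below; the zero arrays are placeholders and the checkpoint name is illustrative):

```python
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

feature_extractor = MaskFormerFeatureExtractor()
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")

inputs = feature_extractor(
    [np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
    segmentation_maps=[np.zeros((384, 384)), np.zeros((384, 384))],
    return_tensors="pt",
)
# pixel_values/pixel_mask are batched tensors; mask_labels/class_labels are lists of per-image tensors
outputs = model(**inputs)  # the loss is computed because labels were passed
print(outputs.loss)
```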
Conversion script:

```diff
@@ -169,12 +169,15 @@ class OriginalMaskFormerConfigToFeatureExtractorConverter:
     def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
         model = original_config.MODEL
         model_input = original_config.INPUT
+        dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])
         return MaskFormerFeatureExtractor(
             image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
             image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
             size=model_input.MIN_SIZE_TEST,
             max_size=model_input.MAX_SIZE_TEST,
+            num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
+            ignore_index=dataset_catalog.ignore_label,
             size_divisibility=32,  # 32 is required by swin
         )
@@ -552,7 +555,7 @@ class OriginalMaskFormerCheckpointToOursConverter:
         yield config, checkpoint

-def test(original_model, our_model: MaskFormerForInstanceSegmentation):
+def test(original_model, our_model: MaskFormerForInstanceSegmentation, feature_extractor: MaskFormerFeatureExtractor):
     with torch.no_grad():
         original_model = original_model.eval()
@@ -600,8 +603,6 @@ def test(original_model, our_model: MaskFormerForInstanceSegmentation):
         our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)

-        feature_extractor = MaskFormerFeatureExtractor()
-
         our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384))

         assert torch.allclose(
@@ -707,7 +708,7 @@ if __name__ == "__main__":
             mask_former_for_instance_segmentation
         )

-        test(original_model, mask_former_for_instance_segmentation)
+        test(original_model, mask_former_for_instance_segmentation, feature_extractor)

         model_name = get_name(checkpoint_file)
         logger.info(f"🪄 Saving {model_name}")
```
`feature_extraction_maskformer.py`:

```diff
@@ -54,6 +54,10 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         max_size (`int`, *optional*, defaults to 1333):
             The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
             set to `True`.
+        resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
+            An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
+            `PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`. Only has an effect
+            if `do_resize` is set to `True`.
         size_divisibility (`int`, *optional*, defaults to 32):
             Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
             Swin Transformer.
@@ -64,8 +68,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         image_std (`int`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
             The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
             ImageNet std.
-        ignore_index (`int`, *optional*, default to 255):
-            Value of the index (label) to ignore.
+        ignore_index (`int`, *optional*):
+            Value of the index (label) to be removed from the segmentation maps.
+        reduce_labels (`bool`, *optional*, defaults to `False`):
+            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
+            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
+            background label will be replaced by `ignore_index`.
     """
```
`feature_extraction_maskformer.py` (continued):

```diff
@@ -76,24 +84,28 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         do_resize=True,
         size=800,
         max_size=1333,
+        resample=Image.BILINEAR,
         size_divisibility=32,
         do_normalize=True,
         image_mean=None,
         image_std=None,
-        ignore_index=255,
+        ignore_index=None,
+        reduce_labels=False,
         **kwargs
     ):
         super().__init__(**kwargs)
         self.do_resize = do_resize
         self.size = size
         self.max_size = max_size
+        self.resample = resample
         self.size_divisibility = size_divisibility
-        self.ignore_index = ignore_index
         self.do_normalize = do_normalize
         self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406]  # ImageNet mean
         self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225]  # ImageNet std
+        self.ignore_index = ignore_index
+        self.reduce_labels = reduce_labels

-    def _resize(self, image, size, target=None, max_size=None):
+    def _resize_with_size_divisibility(self, image, size, target=None, max_size=None):
         """
         Resize the image to the given size. Size can be min_size (scalar) or (width, height) tuple. If size is an int,
         smaller edge of the image will be matched to this number.
@@ -138,30 +150,19 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             width = int(np.ceil(width / self.size_divisibility)) * self.size_divisibility
             size = (width, height)

-        rescaled_image = self.resize(image, size=size)
-
-        has_target = target is not None
-
-        if has_target:
-            target = target.copy()
-            # store original_size
-            target["original_size"] = image.size
-            if "masks" in target:
-                masks = torch.from_numpy(target["masks"])[:, None].float()
-                # use PyTorch as current workaround
-                # TODO replace by self.resize
-                interpolated_masks = (
-                    nn.functional.interpolate(masks, size=(height, width), mode="nearest")[:, 0] > 0.5
-                ).float()
-                target["masks"] = interpolated_masks.numpy()
-
-        return rescaled_image, target
+        image = self.resize(image, size=size, resample=self.resample)
+
+        if target is not None:
+            target = self.resize(target, size=size, resample=Image.NEAREST)
+
+        return image, target

     def __call__(
         self,
         images: ImageInput,
-        annotations: Union[List[Dict], List[List[Dict]]] = None,
+        segmentation_maps: ImageInput = None,
         pad_and_return_pixel_mask: Optional[bool] = True,
+        instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ) -> BatchFeature:
```
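
Note why the target is resized with `Image.NEAREST` while the image uses the configurable `resample` filter: any interpolating filter would blend neighboring label ids into new, meaningless ids. A quick illustration with hypothetical values:

```python
import numpy as np
from PIL import Image

seg = Image.fromarray(np.array([[1, 1], [7, 7]], dtype=np.uint8))
print(np.unique(seg.resize((4, 4), resample=Image.NEAREST)))   # [1 7] -- label ids preserved
print(np.unique(seg.resize((4, 4), resample=Image.BILINEAR)))  # may contain blended ids such as 3 or 5
```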
`feature_extraction_maskformer.py` (continued):

```diff
@@ -170,6 +171,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         padded up to the largest image in a batch, and a pixel mask is created that indicates which pixels are
         real/which are padding.

+        MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+        will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
+        `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+        [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+        each mask.
+
         <Tip warning={true}>

         NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
@@ -183,10 +190,8 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                 number of channels, H and W are image height and width.

-            annotations (`Dict`, `List[Dict]`, *optional*):
-                The corresponding annotations as dictionary of numpy arrays with the following keys:
-                - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
-                - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
+            segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+                Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations.

             pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
                 Whether or not to pad images up to the largest image in a batch and create a pixel mask.
@@ -196,7 +201,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+            instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
+                If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an
+                instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a
+                dictionary mapping instance ids to label ids to create a semantic segmentation map.
+            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                 If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
                 objects.
@@ -206,15 +216,16 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             - **pixel_values** -- Pixel values to be fed to a model.
             - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
               *"pixel_mask"* is in `self.model_input_names`).
-            - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a
-              model (when `annotations` are provided).
-            - **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when
-              `annotations` are provided).
+            - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
+              (when `annotations` are provided).
+            - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
+              `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
+              `mask_labels[i][j]` if `class_labels[i][j]`.
         """
         # Input type checking for clearer error
         valid_images = False
-        valid_annotations = False
+        valid_segmentation_maps = False

         # Check that images has a valid type
         if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
@@ -228,6 +239,23 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 "Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
                 "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
             )

+        # Check that segmentation maps has a valid type
+        if segmentation_maps is not None:
+            if isinstance(segmentation_maps, (Image.Image, np.ndarray)) or is_torch_tensor(segmentation_maps):
+                valid_segmentation_maps = True
+            elif isinstance(segmentation_maps, (list, tuple)):
+                if (
+                    len(segmentation_maps) == 0
+                    or isinstance(segmentation_maps[0], (Image.Image, np.ndarray))
+                    or is_torch_tensor(segmentation_maps[0])
+                ):
+                    valid_segmentation_maps = True
+            if not valid_segmentation_maps:
+                raise ValueError(
+                    "Segmentation maps must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
+                    "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
+                )
+
         is_batched = bool(
             isinstance(images, (list, tuple))
@@ -236,35 +264,33 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         if not is_batched:
             images = [images]
-            if annotations is not None:
-                annotations = [annotations]
+            if segmentation_maps is not None:
+                segmentation_maps = [segmentation_maps]

-        # Check that annotations has a valid type
-        if annotations is not None:
-            valid_annotations = type(annotations) is list and "masks" in annotations[0] and "labels" in annotations[0]
-            if not valid_annotations:
-                raise ValueError(
-                    "Annotations must of type `Dict` (single image) or `List[Dict]` (batch of images)."
-                    "The annotations must be numpy arrays in the following format:"
-                    "{ 'masks' : the target mask, with shape [C,H,W], 'labels' : the target labels, with shape [C]}"
-                )
-
         # transformations (resizing + normalization)
         if self.do_resize and self.size is not None:
-            if annotations is not None:
-                for idx, (image, target) in enumerate(zip(images, annotations)):
-                    image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size)
+            if segmentation_maps is not None:
+                for idx, (image, target) in enumerate(zip(images, segmentation_maps)):
+                    image, target = self._resize_with_size_divisibility(
+                        image=image, target=target, size=self.size, max_size=self.max_size
+                    )
                     images[idx] = image
-                    annotations[idx] = target
+                    segmentation_maps[idx] = target
             else:
                 for idx, image in enumerate(images):
-                    images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]
+                    images[idx] = self._resize_with_size_divisibility(
+                        image=image, target=None, size=self.size, max_size=self.max_size
+                    )[0]

         if self.do_normalize:
             images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]

         # NOTE I will be always forced to pad them them since they have to be stacked in the batch dim
         encoded_inputs = self.encode_inputs(
-            images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors
+            images,
+            segmentation_maps,
+            pad_and_return_pixel_mask,
+            instance_id_to_semantic_id=instance_id_to_semantic_id,
+            return_tensors=return_tensors,
         )

         # Convert to TensorType
@@ -287,25 +313,57 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 maxes[index] = max(maxes[index], item)
         return maxes

+    def convert_segmentation_map_to_binary_masks(
+        self,
+        segmentation_map: "np.ndarray",
+        instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
+    ):
+        if self.reduce_labels:
+            if self.ignore_index is None:
+                raise ValueError("`ignore_index` must be set when `reduce_labels` is `True`.")
+            segmentation_map[segmentation_map == 0] = self.ignore_index
+            # instances ids start from 1!
+            segmentation_map -= 1
+            segmentation_map[segmentation_map == self.ignore_index - 1] = self.ignore_index
+
+        if instance_id_to_semantic_id is not None:
+            # segmentation_map will be treated as an instance segmentation map where each pixel is a instance id
+            # thus it has to be converted to a semantic segmentation map
+            for instance_id, label_id in instance_id_to_semantic_id.items():
+                segmentation_map[segmentation_map == instance_id] = label_id
+        # get all the labels in the image
+        labels = np.unique(segmentation_map)
+        # remove ignore index (if we have one)
+        if self.ignore_index is not None:
+            labels = labels[labels != self.ignore_index]
+        # helping broadcast by making mask [1,W,H] and labels [C, 1, 1]
+        binary_masks = segmentation_map[None] == labels[:, None, None]
+        return binary_masks.astype(np.float32), labels.astype(np.int64)
+
     def encode_inputs(
         self,
-        pixel_values_list: List["torch.Tensor"],
-        annotations: Optional[List[Dict]] = None,
-        pad_and_return_pixel_mask: Optional[bool] = True,
+        pixel_values_list: List["np.ndarray"],
+        segmentation_maps: ImageInput = None,
+        pad_and_return_pixel_mask: bool = True,
+        instance_id_to_semantic_id: Optional[Dict[int, int]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
     ):
         """
         Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.

+        MaskFormer addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps
+        will be converted to lists of binary masks and their respective labels. Let's see an example, assuming
+        `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels =
+        [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for
+        each mask.
+
         Args:
             pixel_values_list (`List[torch.Tensor]`):
                 List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
                 width)`.

-            annotations (`Dict`, `List[Dict]`, *optional*):
-                The corresponding annotations as dictionary of numpy arrays with the following keys:
-                - **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
-                - **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
+            segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
+                The corresponding semantic segmentation maps with the pixel-wise annotations.

             pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
                 Whether or not to pad images up to the largest image in a batch and create a pixel mask.
```
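
To make the new conversion concrete, here is what `convert_segmentation_map_to_binary_masks` boils down to for a toy map (with the defaults `reduce_labels=False` and `ignore_index=None`); this mirrors the broadcasting trick in the added code:

```python
import numpy as np

seg = np.array([[2, 2], [6, 9]])
labels = np.unique(seg)  # [2 6 9]
# compare the (1, H, W) map against (C, 1, 1) labels -> one binary mask per label present
binary_masks = (seg[None] == labels[:, None, None]).astype(np.float32)
print(binary_masks.shape)  # (3, 2, 2)
print(labels)              # [2 6 9] -- these become class_labels
```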
`feature_extraction_maskformer.py` (continued):

```diff
@@ -315,7 +373,12 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
                 - 1 for pixels that are real (i.e. **not masked**),
                 - 0 for pixels that are padding (i.e. **masked**).

-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+            instance_id_to_semantic_id (`Dict[int, int]`, *optional*):
+                If passed, we treat `segmentation_maps` as an instance segmentation map where each pixel represents an
+                instance id. To convert it to a binary mask of shape (`batch, num_labels, height, width`) we need a
+                dictionary mapping instance ids to label ids to create a semantic segmentation map.
+            return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
                 If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
                 objects.
@@ -325,13 +388,29 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             - **pixel_values** -- Pixel values to be fed to a model.
             - **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
               *"pixel_mask"* is in `self.model_input_names`).
-            - **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width) to be fed to a
-              model (when `annotations` are provided).
-            - **class_labels** -- Optional class labels of shape `(batch_size, num_classes) to be fed to a model (when
-              `annotations` are provided).
+            - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model
+              (when `annotations` are provided).
+            - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when
+              `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of
+              `mask_labels[i][j]` if `class_labels[i][j]`.
         """
         max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])

+        annotations = None
+        if segmentation_maps is not None:
+            segmentation_maps = map(np.array, segmentation_maps)
+            converted_segmentation_maps = []
+            for segmentation_map in segmentation_maps:
+                converted_segmentation_map = self.convert_segmentation_map_to_binary_masks(
+                    segmentation_map, instance_id_to_semantic_id
+                )
+                converted_segmentation_maps.append(converted_segmentation_map)
+
+            annotations = []
+            for mask, classes in converted_segmentation_maps:
+                annotations.append({"masks": mask, "classes": classes})
+
         channels, height, width = max_size
         pixel_values = []
         pixel_mask = []
@@ -339,7 +418,6 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         class_labels = []
         for idx, image in enumerate(pixel_values_list):
             # create padded image
-            if pad_and_return_pixel_mask:
-                padded_image = np.zeros((channels, height, width), dtype=np.float32)
-                padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
-                image = padded_image
+            padded_image = np.zeros((channels, height, width), dtype=np.float32)
+            padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
+            image = padded_image
@@ -348,13 +426,13 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
             if annotations:
                 annotation = annotations[idx]
                 masks = annotation["masks"]
-                if pad_and_return_pixel_mask:
-                    padded_masks = np.zeros((masks.shape[0], height, width), dtype=masks.dtype)
-                    padded_masks[:, : masks.shape[1], : masks.shape[2]] = np.copy(masks)
-                    masks = padded_masks
-                mask_labels.append(masks)
-                class_labels.append(annotation["labels"])
-            if pad_and_return_pixel_mask:
+                # pad mask with `ignore_index`
+                masks = np.pad(
+                    masks,
+                    ((0, 0), (0, height - masks.shape[1]), (0, width - masks.shape[2])),
+                    constant_values=self.ignore_index,
+                )
+                annotation["masks"] = masks
             # create pixel mask
             mask = np.zeros((height, width), dtype=np.int64)
             mask[: image.shape[1], : image.shape[2]] = True
@@ -362,12 +440,15 @@ class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
         # return as BatchFeature
         data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
+        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+        # we cannot batch them since they don't share a common class size
         if annotations:
-            data["mask_labels"] = mask_labels
-            data["class_labels"] = class_labels
-
-        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
+            for label in annotations:
+                mask_labels.append(torch.from_numpy(label["masks"]))
+                class_labels.append(torch.from_numpy(label["classes"]))
+            encoded_inputs["mask_labels"] = mask_labels
+            encoded_inputs["class_labels"] = class_labels

         return encoded_inputs
```
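
For instance segmentation maps, `instance_id_to_semantic_id` translates pixel ids into semantic classes before the binary-mask conversion. A hedged sketch with toy ids (exact shapes after resizing/padding will differ):

```python
import numpy as np
from transformers import MaskFormerFeatureExtractor

feature_extractor = MaskFormerFeatureExtractor()
images = [np.zeros((3, 64, 64), dtype=np.float32)]
# pixels hold *instance* ids: 0 and 1 are two instances of the same semantic class 5
instance_maps = [np.random.randint(0, 2, (64, 64)).astype(np.uint8)]

inputs = feature_extractor(
    images,
    instance_maps,
    instance_id_to_semantic_id={0: 5, 1: 5},
    return_tensors="pt",
)
print([m.shape for m in inputs["mask_labels"]])  # one (labels, height, width) tensor per image
print(inputs["class_labels"])                    # [tensor([5])] -- both instances collapse to class 5
```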
`modeling_maskformer.py`:

```diff
@@ -269,7 +269,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
             A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
             query.
         masks_queries_logits (`torch.FloatTensor`):
-            A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each
+            A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
             query. Note the `+ 1` is needed because we incorporate the null class.
         encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
             Last hidden states (final feature map) of the last stage of the encoder model (backbone).
@@ -424,7 +424,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
     """
     inputs = inputs.sigmoid().flatten(1)
     numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
-    # using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix
+    # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
     return loss
@@ -918,7 +918,9 @@ class MaskFormerSwinBlock(nn.Module):
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

         attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
-        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)  # B H' W' C
+        shifted_windows = window_reverse(
+            attention_windows, self.window_size, height_pad, width_pad
+        )  # B height' width' C

         # reverse cyclic shift
         if self.shift_size > 0:
@@ -1621,7 +1623,7 @@ class MaskFormerHungarianMatcher(nn.Module):
         Params:
             masks_queries_logits (`torch.Tensor`):
-                A tensor` of dim `batch_size, num_queries, num_classes` with the
+                A tensor` of dim `batch_size, num_queries, num_labels` with the
                 classification logits.
             class_queries_logits (`torch.Tensor`):
                 A tensor` of dim `batch_size, num_queries, height, width` with the
@@ -1644,24 +1646,23 @@ class MaskFormerHungarianMatcher(nn.Module):
         indices: List[Tuple[np.array]] = []

         preds_masks = masks_queries_logits
-        preds_probs = class_queries_logits.softmax(dim=-1)
-        # downsample all masks in one go -> save memory
-        mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest")
+        preds_probs = class_queries_logits
         # iterate through batch size
         for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
+            # downsample the target mask, save memory
+            target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
+            pred_probs = pred_probs.softmax(-1)
             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
             # but approximate it in 1 - proba[target class].
             # The 1 is a constant that doesn't change the matching, it can be ommitted.
             cost_class = -pred_probs[:, labels]
             # flatten spatial dimension "q h w -> q (h w)"
-            num_queries, height, width = pred_mask.shape
-            pred_mask_flat = pred_mask.view(num_queries, height * width)  # [num_queries, H*W]
+            pred_mask_flat = pred_mask.flatten(1)  # [num_queries, height*width]
             # same for target_mask "c h w -> c (h w)"
-            num_channels, height, width = target_mask.shape
-            target_mask_flat = target_mask.view(num_channels, height * width)  # [num_total_labels, H*W]
-            # compute the focal loss between each mask pairs -> shape [NUM_QUERIES, CLASSES]
+            target_mask_flat = target_mask[:, 0].flatten(1)  # [num_total_labels, height*width]
+            # compute the focal loss between each mask pairs -> shape (num_queries, num_labels)
             cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
-            # Compute the dice loss betwen each mask pairs -> shape [NUM_QUERIES, CLASSES]
+            # Compute the dice loss betwen each mask pairs -> shape (num_queries, num_labels)
             cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
             # final cost matrix
             cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
```
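
For readers new to the matcher: once the `(num_queries, num_labels)` cost matrix above is assembled, MaskFormer solves the assignment with the Hungarian algorithm (hence the scipy requirement on the loss module). A minimal sketch of that final step:

```python
import torch
from scipy.optimize import linear_sum_assignment

cost_matrix = torch.rand(100, 3)  # e.g. 100 queries, 3 ground-truth masks in this image
pred_indices, target_indices = linear_sum_assignment(cost_matrix.numpy())
# each ground-truth mask is matched to exactly one query, minimizing the total cost
print(pred_indices, target_indices)
```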
`modeling_maskformer.py` (continued):

```diff
@@ -1691,7 +1692,7 @@ class MaskFormerHungarianMatcher(nn.Module):
 class MaskFormerLoss(nn.Module):
     def __init__(
         self,
-        num_classes: int,
+        num_labels: int,
         matcher: MaskFormerHungarianMatcher,
         weight_dict: Dict[str, float],
         eos_coef: float,
@@ -1702,7 +1703,7 @@ class MaskFormerLoss(nn.Module):
         matched ground-truth / prediction (supervise class and mask)

         Args:
-            num_classes (`int`):
+            num_labels (`int`):
                 The number of classes.
             matcher (`MaskFormerHungarianMatcher`):
                 A torch module that computes the assigments between the predictions and labels.
@@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module):
         super().__init__()
         requires_backends(self, ["scipy"])
-        self.num_classes = num_classes
+        self.num_labels = num_labels
         self.matcher = matcher
         self.weight_dict = weight_dict
         self.eos_coef = eos_coef
-        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight = torch.ones(self.num_labels + 1)
         empty_weight[-1] = self.eos_coef
         self.register_buffer("empty_weight", empty_weight)

+    def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
+        maxes = the_list[0]
+        for sublist in the_list[1:]:
+            for index, item in enumerate(sublist):
+                maxes[index] = max(maxes[index], item)
+        return maxes
+
+    def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
+        # get the maximum size in the batch
+        max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
+        batch_size = len(tensors)
+        # compute finel size
+        batch_shape = [batch_size] + max_size
+        b, _, h, w = batch_shape
+        # get metadata
+        dtype = tensors[0].dtype
+        device = tensors[0].device
+        padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
+        padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
+        # pad the tensors to the size of the biggest one
+        for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
+            padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
+            padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
+        return padded_tensors, padding_masks
+
     def loss_labels(
-        self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array]
+        self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
     ) -> Dict[str, Tensor]:
         """Compute the losses related to the labels using cross entropy.

         Args:
             class_queries_logits (`torch.Tensor`):
-                A tensor of shape `batch_size, num_queries, num_classes`
-            class_labels (`Dict[str, Tensor]`):
-                A tensor of shape `batch_size, num_classes`
+                A tensor of shape `batch_size, num_queries, num_labels`
+            class_labels (`List[torch.Tensor]`):
+                List of class labels of shape `(labels)`.
             indices (`Tuple[np.array])`:
                 The indices computed by the Hungarian matcher.
```
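
The new `_pad_images_to_max_in_batch` helper exists because `mask_labels` is now a list of per-image tensors whose first (`labels`) dimension and spatial size can differ; they must be zero-padded to a common shape before fancy indexing. A quick sketch of the behavior:

```python
import torch

masks_a = torch.ones(2, 3, 3)  # 2 masks of 3x3
masks_b = torch.ones(5, 4, 4)  # 5 masks of 4x4
max_shape = [max(a, b) for a, b in zip(masks_a.shape, masks_b.shape)]  # [5, 4, 4]
padded = torch.zeros(2, *max_shape)
for i, t in enumerate([masks_a, masks_b]):
    padded[i, : t.shape[0], : t.shape[1], : t.shape[2]] = t
print(padded.shape)  # torch.Size([2, 5, 4, 4]) -- a batch of zero-padded mask stacks
```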
`modeling_maskformer.py` (continued):

```diff
@@ -1744,21 +1771,21 @@ class MaskFormerLoss(nn.Module):
         batch_size, num_queries, _ = pred_logits.shape
         criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
         idx = self._get_predictions_permutation_indices(indices)
-        # shape = [BATCH, N_QUERIES]
+        # shape = (batch_size, num_queries)
         target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
-        # shape = [BATCH, N_QUERIES]
+        # shape = (batch_size, num_queries)
         target_classes = torch.full(
-            (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
+            (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
         )
         target_classes[idx] = target_classes_o
-        # target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q"
-        pred_logits_permuted = pred_logits.permute(0, 2, 1)
-        loss_ce = criterion(pred_logits_permuted, target_classes)
+        # target_classes is a (batch_size, num_labels, num_queries), we need to permute pred_logits "b q c -> b c q"
+        pred_logits_transposed = pred_logits.transpose(1, 2)
+        loss_ce = criterion(pred_logits_transposed, target_classes)
         losses = {"loss_cross_entropy": loss_ce}
         return losses

     def loss_masks(
-        self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int
+        self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
     ) -> Dict[str, Tensor]:
         """Compute the losses related to the masks using focal and dice loss.
@@ -1766,7 +1793,7 @@ class MaskFormerLoss(nn.Module):
             masks_queries_logits (`torch.Tensor`):
                 A tensor of shape `batch_size, num_queries, height, width`
             mask_labels (`torch.Tensor`):
-                A tensor of shape `batch_size, num_queries, height, width`
+                List of mask labels of shape `(labels, height, width)`.
             indices (`Tuple[np.array])`:
                 The indices computed by the Hungarian matcher.
             num_masks (`int)`:
@@ -1780,10 +1807,12 @@ class MaskFormerLoss(nn.Module):
         """
         src_idx = self._get_predictions_permutation_indices(indices)
         tgt_idx = self._get_targets_permutation_indices(indices)
-        pred_masks = masks_queries_logits  # shape [BATCH, NUM_QUERIES, H, W]
-        pred_masks = pred_masks[src_idx]  # shape [BATCH * NUM_QUERIES, H, W]
-        target_masks = mask_labels  # shape [BATCH, NUM_QUERIES, H, W]
-        target_masks = target_masks[tgt_idx]  # shape [BATCH * NUM_QUERIES, H, W]
+        # shape (batch_size * num_queries, height, width)
+        pred_masks = masks_queries_logits[src_idx]
+        # shape (batch_size, num_queries, height, width)
+        # pad all and stack the targets to the num_labels dimension
+        target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
+        target_masks = target_masks[tgt_idx]
         # upsample predictions to the target size, we have to add one dim to use interpolate
         pred_masks = nn.functional.interpolate(
             pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
@@ -1791,7 +1820,6 @@ class MaskFormerLoss(nn.Module):
         pred_masks = pred_masks[:, 0].flatten(1)

         target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(pred_masks.shape)
         losses = {
             "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
             "loss_dice": dice_loss(pred_masks, target_masks, num_masks),
@@ -1810,19 +1838,13 @@ class MaskFormerLoss(nn.Module):
         target_indices = torch.cat([tgt for (_, tgt) in indices])
         return batch_indices, target_indices

-    def get_loss(self, loss, outputs, labels, indices, num_masks):
-        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
-        if loss not in loss_map:
-            raise KeyError(f"{loss} not in loss_map")
-        return loss_map[loss](outputs, labels, indices, num_masks)
-
     def forward(
         self,
-        masks_queries_logits: torch.Tensor,
-        class_queries_logits: torch.Tensor,
-        mask_labels: torch.Tensor,
-        class_labels: torch.Tensor,
-        auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
+        masks_queries_logits: Tensor,
+        class_queries_logits: Tensor,
+        mask_labels: List[Tensor],
+        class_labels: List[Tensor],
+        auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
     ) -> Dict[str, Tensor]:
         """
         This performs the loss computation.
@@ -1831,11 +1853,11 @@ class MaskFormerLoss(nn.Module):
             masks_queries_logits (`torch.Tensor`):
                 A tensor of shape `batch_size, num_queries, height, width`
             class_queries_logits (`torch.Tensor`):
-                A tensor of shape `batch_size, num_queries, num_classes`
+                A tensor of shape `batch_size, num_queries, num_labels`
             mask_labels (`torch.Tensor`):
-                A tensor of shape `batch_size, num_classes, height, width`
-            class_labels (`torch.Tensor`):
-                A tensor of shape `batch_size, num_classes`
+                List of mask labels of shape `(labels, height, width)`.
+            class_labels (`List[torch.Tensor]`):
+                List of class labels of shape `(labels)`.
             auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
                 if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
                 inner layers of the Detr's Decoder.
@@ -1850,19 +1872,16 @@ class MaskFormerLoss(nn.Module):
             for each auxiliary predictions.
         """

-        # Retrieve the matching between the outputs of the last layer and the labels
+        # retrieve the matching between the outputs of the last layer and the labels
         indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
-
-        # Compute the average number of target masks accross all nodes, for normalization purposes
-        num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device)
-
-        # Compute all the requested losses
+        # compute the average number of target masks for normalization purposes
+        num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
+        # get all the losses
         losses: Dict[str, Tensor] = {
             **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
             **self.loss_labels(class_queries_logits, class_labels, indices),
         }
-
-        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
         if auxiliary_predictions is not None:
             for idx, aux_outputs in enumerate(auxiliary_predictions):
                 masks_queries_logits = aux_outputs["masks_queries_logits"]
@@ -1874,8 +1893,10 @@ class MaskFormerLoss(nn.Module):
         return losses

     def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
-        # Compute the average number of target masks accross all nodes, for normalization purposes
-        num_masks = class_labels.shape[0]
+        """
+        Computes the average number of target masks accross the batch, for normalization purposes.
+        """
+        num_masks = sum([len(classes) for classes in class_labels])
         num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
         return num_masks_pt
@@ -2380,11 +2401,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         loss_dict: Dict[str, Tensor] = self.criterion(
             masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
         )
-        # weight each loss by `self.weight_dict[<LOSS_NAME>]`
-        weighted_loss_dict: Dict[str, Tensor] = {
-            k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict
-        }
-        return weighted_loss_dict
+        # weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
+        for key, weight in self.weight_dict.items():
+            for loss_key, loss in loss_dict.items():
+                if key in loss_key:
+                    loss *= weight
+        return loss_dict

     def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
         return sum(loss_dict.values())
```
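
One subtlety in the new weighting loop: `loss *= weight` only updates `loss_dict` because `*=` on a `torch.Tensor` mutates the tensor in place (the dict still points to it), and the substring test `key in loss_key` is what lets a single weight also cover the auxiliary variants. A minimal illustration:

```python
import torch

losses = {"loss_mask": torch.tensor(2.0), "loss_mask_0": torch.tensor(3.0)}
weights = {"loss_mask": 0.5}
for key, weight in weights.items():
    for loss_key, loss in losses.items():
        if key in loss_key:  # also matches the auxiliary "loss_mask_0"
            loss *= weight   # in-place on the tensor, so the dict entry changes too
print(losses)  # {'loss_mask': tensor(1.), 'loss_mask_0': tensor(1.5000)}
```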
`modeling_maskformer.py` (continued):

```diff
@@ -2425,8 +2448,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
     def forward(
         self,
         pixel_values: Tensor,
-        mask_labels: Optional[Tensor] = None,
-        class_labels: Optional[Tensor] = None,
+        mask_labels: Optional[List[Tensor]] = None,
+        class_labels: Optional[List[Tensor]] = None,
         pixel_mask: Optional[Tensor] = None,
         output_auxiliary_logits: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -2434,10 +2457,11 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> MaskFormerForInstanceSegmentationOutput:
         r"""
-        mask_labels (`torch.FloatTensor`, *optional*):
-            The target mask of shape `(num_classes, height, width)`.
-        class_labels (`torch.LongTensor`, *optional*):
-            The target labels of shape `(num_classes)`.
+        mask_labels (`List[torch.Tensor]`, *optional*):
+            List of mask labels of shape `(num_labels, height, width)` to be fed to a model
+        class_labels (`List[torch.LongTensor]`, *optional*):
+            list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
+            labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.

         Returns:
```
`test_feature_extraction_maskformer.py`:

```diff
@@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
+        num_labels=10,
+        reduce_labels=True,
+        ignore_index=255,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
         self.num_classes = 2
         self.height = 3
         self.width = 4
+        self.num_labels = num_labels
+        self.reduce_labels = reduce_labels
+        self.ignore_index = ignore_index

     def prepare_feat_extract_dict(self):
         return {
@@ -78,6 +84,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
             "image_mean": self.image_mean,
             "image_std": self.image_std,
             "size_divisibility": self.size_divisibility,
+            "num_labels": self.num_labels,
+            "reduce_labels": self.reduce_labels,
+            "ignore_index": self.ignore_index,
         }

     def get_expected_values(self, image_inputs, batched=False):
@@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertTrue(hasattr(feature_extractor, "do_resize"))
         self.assertTrue(hasattr(feature_extractor, "size"))
         self.assertTrue(hasattr(feature_extractor, "max_size"))
+        self.assertTrue(hasattr(feature_extractor, "ignore_index"))
+        self.assertTrue(hasattr(feature_extractor, "num_labels"))

     def test_batch_feature(self):
         pass
@@ -245,7 +256,9 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
     def test_equivalence_pad_and_create_pixel_mask(self):
         # Initialize feature_extractors
         feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
-        feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+        feature_extractor_2 = self.feature_extraction_class(
+            do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes
+        )
         # create random PyTorch tensors
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
         for image in image_inputs:
@@ -262,28 +275,41 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
         )

-    def comm_get_feature_extractor_inputs(self, with_annotations=False):
+    def comm_get_feature_extractor_inputs(
+        self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
+    ):
         feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
         # prepare image and target
-        num_classes = 8
         batch_size = self.feature_extract_tester.batch_size
+        num_labels = self.feature_extract_tester.num_labels
         annotations = None
+        instance_id_to_semantic_id = None

-        if with_annotations:
-            annotations = [
-                {
-                    "masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
-                    "labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
-                }
-                for _ in range(batch_size)
-            ]
+        if with_segmentation_maps:
+            high = num_labels
+            if is_instance_map:
+                high * 2
+                labels_expanded = list(range(num_labels)) * 2
+                instance_id_to_semantic_id = {
+                    instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
+                }
+            annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)]
+            if segmentation_type == "pil":
+                annotations = [Image.fromarray(annotation) for annotation in annotations]

         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
-        inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
+        inputs = feature_extractor(
+            image_inputs,
+            annotations,
+            return_tensors="pt",
+            instance_id_to_semantic_id=instance_id_to_semantic_id,
+            pad_and_return_pixel_mask=True,
+        )

         return inputs

+    def test_init_without_params(self):
+        pass
+
     def test_with_size_divisibility(self):
         size_divisibilities = [8, 16, 32]
         weird_input_sizes = [(407, 802), (582, 1094)]
@@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
             self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)

-    def test_call_with_numpy_annotations(self):
-        num_classes = 8
-        batch_size = self.feature_extract_tester.batch_size
-
-        inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
-
-        # check the batch_size
-        for el in inputs.values():
-            self.assertEqual(el.shape[0], batch_size)
-
-        pixel_values = inputs["pixel_values"]
-        mask_labels = inputs["mask_labels"]
-        class_labels = inputs["class_labels"]
-
-        self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
-        self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
-        self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
-        self.assertEqual(mask_labels.shape[1], num_classes)
+    def test_call_with_segmentation_maps(self):
+        def common(is_instance_map=False, segmentation_type=None):
+            inputs = self.comm_get_feature_extractor_inputs(
+                with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
+            )
+
+            mask_labels = inputs["mask_labels"]
+            class_labels = inputs["class_labels"]
+            pixel_values = inputs["pixel_values"]
+
+            # check the batch_size
+            for mask_label, class_label in zip(mask_labels, class_labels):
+                self.assertEqual(mask_label.shape[0], class_label.shape[0])
+                # this ensure padding has happened
+                self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
+
+        common()
+        common(is_instance_map=True)
+        common(is_instance_map=False, segmentation_type="pil")
+        common(is_instance_map=True, segmentation_type="pil")

     def test_post_process_segmentation(self):
-        fature_extractor = self.feature_extraction_class()
+        fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
         outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
         segmentation = fature_extractor.post_process_segmentation(outputs)
@@ -340,7 +368,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         )

     def test_post_process_semantic_segmentation(self):
-        fature_extractor = self.feature_extraction_class()
+        fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
         outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
         segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
@@ -361,7 +389,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))

     def test_post_process_panoptic_segmentation(self):
-        fature_extractor = self.feature_extraction_class()
+        fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
         outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
         segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
```
`test_modeling_maskformer.py`:

```diff
@@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

-    def test_with_annotations_and_loss(self):
+    def test_with_segmentation_maps_and_loss(self):
         model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
         feature_extractor = self.default_feature_extractor

         inputs = feature_extractor(
             [np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
-            annotations=[
-                {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
-                {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
-            ],
+            segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
             return_tensors="pt",
-        ).to(torch_device)
+        )
+
+        inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
+        inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
+        inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]

         with torch.no_grad():
             outputs = model(**inputs)
```