Unverified commit c4deb7b3, authored by Francesco Saverio Zuppichini, committed by GitHub

Feature Extractor accepts `segmentation_maps` (#15964)



* feature extractor accepts

* resolved conversations

* added examples in test for ADE20K

* num_classes -> num_labels

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* resolving conversations

* resolving conversations

* removed ADE

* CI

* minor changes in conversion script

* reduce_labels in feature extractor

* minor changes

* correct preprocess for instance segmentation maps

* minor changes

* minor changes

* CI

* debugging

* better padding

* going to update labels inside the model

* going to update labels inside the model

* minor changes

* tests

* removed changes in feature_extractor_utils

* conversation

* conversation

* example in feature extractor

* more docstring in modeling

* test

* make style

* doc
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c2f8eaf6
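In short: the feature extractor now takes `segmentation_maps` alongside images and returns `mask_labels` and `class_labels` as lists of tensors (one entry per image, since each image can carry a different number of labels), which feed straight into the model's loss. A minimal sketch of the new call pattern, mirroring the integration test at the bottom of this diff; the checkpoint name is illustrative, not part of this commit:

    import numpy as np
    from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

    # illustrative checkpoint; any MaskFormer checkpoint with a matching extractor works
    model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
    feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")

    image = np.zeros((3, 800, 1333), dtype=np.float32)
    segmentation_map = np.zeros((384, 384), dtype=np.uint8)  # per-pixel class ids

    inputs = feature_extractor([image], segmentation_maps=[segmentation_map], return_tensors="pt")
    # mask_labels / class_labels come back as lists of tensors, one per image
    outputs = model(
        pixel_values=inputs["pixel_values"],
        mask_labels=inputs["mask_labels"],
        class_labels=inputs["class_labels"],
    )
    print(outputs.loss)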
@@ -169,12 +169,15 @@ class OriginalMaskFormerConfigToFeatureExtractorConverter:
     def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
         model = original_config.MODEL
         model_input = original_config.INPUT
+        dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])

         return MaskFormerFeatureExtractor(
             image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
             image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
             size=model_input.MIN_SIZE_TEST,
             max_size=model_input.MAX_SIZE_TEST,
+            num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
+            ignore_index=dataset_catalog.ignore_label,
             size_divisibility=32,  # 32 is required by swin
         )
@@ -552,7 +555,7 @@ class OriginalMaskFormerCheckpointToOursConverter:
         yield config, checkpoint

-def test(original_model, our_model: MaskFormerForInstanceSegmentation):
+def test(original_model, our_model: MaskFormerForInstanceSegmentation, feature_extractor: MaskFormerFeatureExtractor):
     with torch.no_grad():
         original_model = original_model.eval()
@@ -600,8 +603,6 @@ def test(original_model, our_model: MaskFormerForInstanceSegmentation):
         our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)

-        feature_extractor = MaskFormerFeatureExtractor()
-
         our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384))

         assert torch.allclose(
@@ -707,7 +708,7 @@ if __name__ == "__main__":
             mask_former_for_instance_segmentation
         )

-        test(original_model, mask_former_for_instance_segmentation)
+        test(original_model, mask_former_for_instance_segmentation, feature_extractor)

         model_name = get_name(checkpoint_file)
         logger.info(f"🪄 Saving {model_name}")
......
@@ -269,7 +269,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
             A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
             query.
         class_queries_logits (`torch.FloatTensor`):
-            A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each
+            A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
             query. Note the `+ 1` is needed because we incorporate the null class.
         encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
             Last hidden states (final feature map) of the last stage of the encoder model (backbone).
@@ -424,7 +424,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
     """
     inputs = inputs.sigmoid().flatten(1)
     numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
-    # using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix
+    # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
     denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
     loss = 1 - (numerator + 1) / (denominator + 1)
     return loss
@@ -918,7 +918,9 @@ class MaskFormerSwinBlock(nn.Module):
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

         attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
-        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)  # B H' W' C
+        shifted_windows = window_reverse(
+            attention_windows, self.window_size, height_pad, width_pad
+        )  # B height' width' C

         # reverse cyclic shift
         if self.shift_size > 0:
@@ -1621,7 +1623,7 @@ class MaskFormerHungarianMatcher(nn.Module):
         Params:
             masks_queries_logits (`torch.Tensor`):
-                A tensor of dim `batch_size, num_queries, num_classes` with the
+                A tensor of dim `batch_size, num_queries, num_labels` with the
                 classification logits.
             class_queries_logits (`torch.Tensor`):
                 A tensor of dim `batch_size, num_queries, height, width` with the
@@ -1644,24 +1646,23 @@ class MaskFormerHungarianMatcher(nn.Module):
         indices: List[Tuple[np.array]] = []

         preds_masks = masks_queries_logits
-        preds_probs = class_queries_logits.softmax(dim=-1)
-        # downsample all masks in one go -> save memory
-        mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest")
+        preds_probs = class_queries_logits
         # iterate through batch size
         for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
+            # downsample the target mask, save memory
+            target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
+            pred_probs = pred_probs.softmax(-1)
             # Compute the classification cost. Contrary to the loss, we don't use the NLL,
             # but approximate it in 1 - proba[target class].
             # The 1 is a constant that doesn't change the matching, it can be omitted.
             cost_class = -pred_probs[:, labels]
             # flatten spatial dimension "q h w -> q (h w)"
-            num_queries, height, width = pred_mask.shape
-            pred_mask_flat = pred_mask.view(num_queries, height * width)  # [num_queries, H*W]
+            pred_mask_flat = pred_mask.flatten(1)  # [num_queries, height*width]
             # same for target_mask "c h w -> c (h w)"
-            num_channels, height, width = target_mask.shape
-            target_mask_flat = target_mask.view(num_channels, height * width)  # [num_total_labels, H*W]
-            # compute the focal loss between each mask pairs -> shape [NUM_QUERIES, CLASSES]
+            target_mask_flat = target_mask[:, 0].flatten(1)  # [num_total_labels, height*width]
+            # compute the focal loss between each mask pairs -> shape (num_queries, num_labels)
             cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
-            # Compute the dice loss between each mask pairs -> shape [NUM_QUERIES, CLASSES]
+            # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
             cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
             # final cost matrix
             cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
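Right after this hunk (unchanged by the commit), each per-image cost_matrix of shape (num_queries, num_total_labels) is handed to scipy's Hungarian solver, which is why the loss requires the scipy backend. A self-contained sketch of that step with a toy cost matrix:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # toy stand-in for the per-image cost matrix assembled above:
    # rows are queries, columns are ground-truth masks/labels
    cost_matrix = np.array(
        [[0.2, 0.9, 0.5],
         [0.8, 0.1, 0.7]]
    )
    # minimizes the total cost over one-to-one query/label assignments
    query_indices, label_indices = linear_sum_assignment(cost_matrix)
    print(query_indices, label_indices)  # [0 1] [0 1]: query 0 -> label 0, query 1 -> label 1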
@@ -1691,7 +1692,7 @@ class MaskFormerHungarianMatcher(nn.Module):
 class MaskFormerLoss(nn.Module):
     def __init__(
         self,
-        num_classes: int,
+        num_labels: int,
         matcher: MaskFormerHungarianMatcher,
         weight_dict: Dict[str, float],
         eos_coef: float,
@@ -1702,7 +1703,7 @@ class MaskFormerLoss(nn.Module):
         matched ground-truth / prediction (supervise class and mask)

         Args:
-            num_classes (`int`):
+            num_labels (`int`):
                 The number of classes.
             matcher (`MaskFormerHungarianMatcher`):
                 A torch module that computes the assignments between the predictions and labels.
...@@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module): ...@@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module):
super().__init__() super().__init__()
requires_backends(self, ["scipy"]) requires_backends(self, ["scipy"])
self.num_classes = num_classes self.num_labels = num_labels
self.matcher = matcher self.matcher = matcher
self.weight_dict = weight_dict self.weight_dict = weight_dict
self.eos_coef = eos_coef self.eos_coef = eos_coef
empty_weight = torch.ones(self.num_classes + 1) empty_weight = torch.ones(self.num_labels + 1)
empty_weight[-1] = self.eos_coef empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight) self.register_buffer("empty_weight", empty_weight)
def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
# get the maximum size in the batch
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
batch_size = len(tensors)
# compute finel size
batch_shape = [batch_size] + max_size
b, _, h, w = batch_shape
# get metadata
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
# pad the tensors to the size of the biggest one
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
def loss_labels( def loss_labels(
self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array] self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]: ) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy. """Compute the losses related to the labels using cross entropy.
Args: Args:
class_queries_logits (`torch.Tensor`): class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes` A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`Dict[str, Tensor]`): class_labels (`List[torch.Tensor]`):
A tensor of shape `batch_size, num_classes` List of class labels of shape `(labels)`.
indices (`Tuple[np.array])`: indices (`Tuple[np.array])`:
The indices computed by the Hungarian matcher. The indices computed by the Hungarian matcher.
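The new _pad_images_to_max_in_batch helper is what makes variable-length label lists stackable: each target is zero-padded to the largest size in the batch, on the label dimension as well as spatially, and the boolean mask it returns is True exactly on padded pixels. A standalone sketch of the same padding logic, under those assumptions:

    import torch

    # two images with a different number of masks and different spatial sizes
    masks_a = torch.ones(2, 3, 4)  # 2 binary masks of size 3x4
    masks_b = torch.ones(1, 5, 2)  # 1 binary mask of size 5x2

    max_size = [max(dims) for dims in zip(masks_a.shape, masks_b.shape)]  # [2, 5, 4]
    padded = torch.zeros(2, *max_size)  # (batch_size, max_labels, max_height, max_width)
    for i, masks in enumerate([masks_a, masks_b]):
        padded[i, : masks.shape[0], : masks.shape[1], : masks.shape[2]] = masks
    print(padded.shape)  # torch.Size([2, 2, 5, 4])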
@@ -1744,21 +1771,21 @@ class MaskFormerLoss(nn.Module):
         batch_size, num_queries, _ = pred_logits.shape
         criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
         idx = self._get_predictions_permutation_indices(indices)
-        # shape = [BATCH, N_QUERIES]
+        # shape = (batch_size, num_queries)
         target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
-        # shape = [BATCH, N_QUERIES]
+        # shape = (batch_size, num_queries)
         target_classes = torch.full(
-            (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
+            (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
         )
         target_classes[idx] = target_classes_o
-        # target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q"
-        pred_logits_permuted = pred_logits.permute(0, 2, 1)
-        loss_ce = criterion(pred_logits_permuted, target_classes)
+        # target_classes is a (batch_size, num_labels, num_queries), we need to permute pred_logits "b q c -> b c q"
+        pred_logits_transposed = pred_logits.transpose(1, 2)
+        loss_ce = criterion(pred_logits_transposed, target_classes)
         losses = {"loss_cross_entropy": loss_ce}
         return losses

     def loss_masks(
-        self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int
+        self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
     ) -> Dict[str, Tensor]:
         """Compute the losses related to the masks using focal and dice loss.
@@ -1766,7 +1793,7 @@ class MaskFormerLoss(nn.Module):
             masks_queries_logits (`torch.Tensor`):
                 A tensor of shape `batch_size, num_queries, height, width`
             mask_labels (`torch.Tensor`):
-                A tensor of shape `batch_size, num_queries, height, width`
+                List of mask labels of shape `(labels, height, width)`.
             indices (`Tuple[np.array])`:
                 The indices computed by the Hungarian matcher.
             num_masks (`int)`:
@@ -1780,10 +1807,12 @@ class MaskFormerLoss(nn.Module):
         """
         src_idx = self._get_predictions_permutation_indices(indices)
         tgt_idx = self._get_targets_permutation_indices(indices)
-        pred_masks = masks_queries_logits  # shape [BATCH, NUM_QUERIES, H, W]
-        pred_masks = pred_masks[src_idx]  # shape [BATCH * NUM_QUERIES, H, W]
-        target_masks = mask_labels  # shape [BATCH, NUM_QUERIES, H, W]
-        target_masks = target_masks[tgt_idx]  # shape [BATCH * NUM_QUERIES, H, W]
+        # shape (batch_size * num_queries, height, width)
+        pred_masks = masks_queries_logits[src_idx]
+        # shape (batch_size, num_queries, height, width)
+        # pad all and stack the targets to the num_labels dimension
+        target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
+        target_masks = target_masks[tgt_idx]
         # upsample predictions to the target size, we have to add one dim to use interpolate
         pred_masks = nn.functional.interpolate(
             pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
@@ -1791,7 +1820,6 @@ class MaskFormerLoss(nn.Module):
         pred_masks = pred_masks[:, 0].flatten(1)

         target_masks = target_masks.flatten(1)
-        target_masks = target_masks.view(pred_masks.shape)
         losses = {
             "loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
             "loss_dice": dice_loss(pred_masks, target_masks, num_masks),
@@ -1810,19 +1838,13 @@ class MaskFormerLoss(nn.Module):
         target_indices = torch.cat([tgt for (_, tgt) in indices])
         return batch_indices, target_indices

-    def get_loss(self, loss, outputs, labels, indices, num_masks):
-        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
-        if loss not in loss_map:
-            raise KeyError(f"{loss} not in loss_map")
-        return loss_map[loss](outputs, labels, indices, num_masks)
-
     def forward(
         self,
-        masks_queries_logits: torch.Tensor,
-        class_queries_logits: torch.Tensor,
-        mask_labels: torch.Tensor,
-        class_labels: torch.Tensor,
-        auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
+        masks_queries_logits: Tensor,
+        class_queries_logits: Tensor,
+        mask_labels: List[Tensor],
+        class_labels: List[Tensor],
+        auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
     ) -> Dict[str, Tensor]:
         """
         This performs the loss computation.
@@ -1831,11 +1853,11 @@ class MaskFormerLoss(nn.Module):
             masks_queries_logits (`torch.Tensor`):
                 A tensor of shape `batch_size, num_queries, height, width`
             class_queries_logits (`torch.Tensor`):
-                A tensor of shape `batch_size, num_queries, num_classes`
+                A tensor of shape `batch_size, num_queries, num_labels`
             mask_labels (`torch.Tensor`):
-                A tensor of shape `batch_size, num_classes, height, width`
-            class_labels (`torch.Tensor`):
-                A tensor of shape `batch_size, num_classes`
+                List of mask labels of shape `(labels, height, width)`.
+            class_labels (`List[torch.Tensor]`):
+                List of class labels of shape `(labels)`.
             auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
                 if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
                 inner layers of the Detr's Decoder.
@@ -1850,19 +1872,16 @@ class MaskFormerLoss(nn.Module):
                 for each auxiliary predictions.
         """

-        # Retrieve the matching between the outputs of the last layer and the labels
+        # retrieve the matching between the outputs of the last layer and the labels
         indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
-
-        # Compute the average number of target masks accross all nodes, for normalization purposes
-        num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device)
-
-        # Compute all the requested losses
+        # compute the average number of target masks for normalization purposes
+        num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
+        # get all the losses
         losses: Dict[str, Tensor] = {
             **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
             **self.loss_labels(class_queries_logits, class_labels, indices),
         }
-
-        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
         if auxiliary_predictions is not None:
             for idx, aux_outputs in enumerate(auxiliary_predictions):
                 masks_queries_logits = aux_outputs["masks_queries_logits"]
@@ -1874,8 +1893,10 @@ class MaskFormerLoss(nn.Module):
         return losses

     def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
-        # Compute the average number of target masks accross all nodes, for normalization purposes
-        num_masks = class_labels.shape[0]
+        """
+        Computes the average number of target masks across the batch, for normalization purposes.
+        """
+        num_masks = sum([len(classes) for classes in class_labels])
         num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
         return num_masks_pt
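Since class_labels is now a list whose entries can have different lengths, the normalizer counts masks by summing list lengths instead of reading a single tensor dimension. A quick sketch of the difference:

    import torch

    class_labels = [torch.tensor([1, 4, 7]), torch.tensor([2])]  # 3 labels + 1 label
    num_masks = sum(len(classes) for classes in class_labels)
    print(num_masks)  # 4 (the old `class_labels.shape[0]` would only have seen the batch size)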
@@ -2380,11 +2401,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         loss_dict: Dict[str, Tensor] = self.criterion(
             masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
         )
-        # weight each loss by `self.weight_dict[<LOSS_NAME>]`
-        weighted_loss_dict: Dict[str, Tensor] = {
-            k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict
-        }
-        return weighted_loss_dict
+        # weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
+        for key, weight in self.weight_dict.items():
+            for loss_key, loss in loss_dict.items():
+                if key in loss_key:
+                    loss *= weight
+
+        return loss_dict

     def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
         return sum(loss_dict.values())
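Two details make the replacement work: `key in loss_key` is substring matching, so auxiliary entries such as `loss_mask_0` pick up the `loss_mask` weight (which the old dict comprehension silently dropped), and `loss *= weight` multiplies the tensor in place, so the entries of loss_dict are updated while iterating. A standalone sketch of the pattern, with made-up weights:

    import torch

    weight_dict = {"loss_mask": 20.0, "loss_dice": 1.0}
    loss_dict = {
        "loss_mask": torch.tensor(0.5),
        "loss_dice": torch.tensor(0.5),
        "loss_mask_0": torch.tensor(0.25),  # auxiliary loss from an intermediate layer
    }
    for key, weight in weight_dict.items():
        for loss_key, loss in loss_dict.items():
            if key in loss_key:
                loss *= weight  # in-place: mutates the tensor stored in loss_dict
    print({k: v.item() for k, v in loss_dict.items()})
    # {'loss_mask': 10.0, 'loss_dice': 0.5, 'loss_mask_0': 5.0}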
@@ -2425,8 +2448,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
     def forward(
         self,
         pixel_values: Tensor,
-        mask_labels: Optional[Tensor] = None,
-        class_labels: Optional[Tensor] = None,
+        mask_labels: Optional[List[Tensor]] = None,
+        class_labels: Optional[List[Tensor]] = None,
         pixel_mask: Optional[Tensor] = None,
         output_auxiliary_logits: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -2434,10 +2457,11 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> MaskFormerForInstanceSegmentationOutput:
         r"""
-        mask_labels (`torch.FloatTensor`, *optional*):
-            The target mask of shape `(num_classes, height, width)`.
-        class_labels (`torch.LongTensor`, *optional*):
-            The target labels of shape `(num_classes)`.
+        mask_labels (`List[torch.Tensor]`, *optional*):
+            List of mask labels of shape `(num_labels, height, width)` to be fed to a model.
+        class_labels (`List[torch.LongTensor]`, *optional*):
+            List of target class labels of shape `(num_labels)` to be fed to a model. They identify the
+            labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.

         Returns:
......
@@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
         do_normalize=True,
         image_mean=[0.5, 0.5, 0.5],
         image_std=[0.5, 0.5, 0.5],
+        num_labels=10,
+        reduce_labels=True,
+        ignore_index=255,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
         self.num_classes = 2
         self.height = 3
         self.width = 4
+        self.num_labels = num_labels
+        self.reduce_labels = reduce_labels
+        self.ignore_index = ignore_index

     def prepare_feat_extract_dict(self):
         return {
@@ -78,6 +84,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
             "image_mean": self.image_mean,
             "image_std": self.image_std,
             "size_divisibility": self.size_divisibility,
+            "num_labels": self.num_labels,
+            "reduce_labels": self.reduce_labels,
+            "ignore_index": self.ignore_index,
         }

     def get_expected_values(self, image_inputs, batched=False):
@@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
        self.assertTrue(hasattr(feature_extractor, "do_resize"))
         self.assertTrue(hasattr(feature_extractor, "size"))
         self.assertTrue(hasattr(feature_extractor, "max_size"))
+        self.assertTrue(hasattr(feature_extractor, "ignore_index"))
+        self.assertTrue(hasattr(feature_extractor, "num_labels"))

     def test_batch_feature(self):
         pass
@@ -245,7 +256,9 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
     def test_equivalence_pad_and_create_pixel_mask(self):
         # Initialize feature_extractors
         feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
-        feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
+        feature_extractor_2 = self.feature_extraction_class(
+            do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes
+        )
         # create random PyTorch tensors
         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)

         for image in image_inputs:
@@ -262,28 +275,41 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
         )

-    def comm_get_feature_extractor_inputs(self, with_annotations=False):
+    def comm_get_feature_extractor_inputs(
+        self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
+    ):
         feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
         # prepare image and target
-        num_classes = 8
         batch_size = self.feature_extract_tester.batch_size
+        num_labels = self.feature_extract_tester.num_labels
         annotations = None
+        instance_id_to_semantic_id = None

-        if with_annotations:
-            annotations = [
-                {
-                    "masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
-                    "labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
-                }
-                for _ in range(batch_size)
-            ]
+        if with_segmentation_maps:
+            high = num_labels
+            if is_instance_map:
+                high *= 2
+                labels_expanded = list(range(num_labels)) * 2
+                instance_id_to_semantic_id = {
+                    instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
+                }
+            annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)]
+            if segmentation_type == "pil":
+                annotations = [Image.fromarray(annotation) for annotation in annotations]

         image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
-
-        inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
+        inputs = feature_extractor(
+            image_inputs,
+            annotations,
+            return_tensors="pt",
+            instance_id_to_semantic_id=instance_id_to_semantic_id,
+            pad_and_return_pixel_mask=True,
+        )

         return inputs

+    def test_init_without_params(self):
+        pass
+
     def test_with_size_divisibility(self):
         size_divisibilities = [8, 16, 32]
         weird_input_sizes = [(407, 802), (582, 1094)]
@@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
             self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
             self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)

-    def test_call_with_numpy_annotations(self):
-        num_classes = 8
-        batch_size = self.feature_extract_tester.batch_size
-
-        inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
-
-        # check the batch_size
-        for el in inputs.values():
-            self.assertEqual(el.shape[0], batch_size)
-
-        pixel_values = inputs["pixel_values"]
-        mask_labels = inputs["mask_labels"]
-        class_labels = inputs["class_labels"]
-
-        self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
-        self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
-        self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
-        self.assertEqual(mask_labels.shape[1], num_classes)
+    def test_call_with_segmentation_maps(self):
+        def common(is_instance_map=False, segmentation_type=None):
+            inputs = self.comm_get_feature_extractor_inputs(
+                with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
+            )
+
+            mask_labels = inputs["mask_labels"]
+            class_labels = inputs["class_labels"]
+            pixel_values = inputs["pixel_values"]
+
+            # check the batch_size
+            for mask_label, class_label in zip(mask_labels, class_labels):
+                self.assertEqual(mask_label.shape[0], class_label.shape[0])
+                # this ensures padding has happened
+                self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
+
+        common()
+        common(is_instance_map=True)
+        common(is_instance_map=False, segmentation_type="pil")
+        common(is_instance_map=True, segmentation_type="pil")
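The is_instance_map path exercises the new instance_id_to_semantic_id argument: the segmentation map's pixels then hold instance ids, and the dict tells the feature extractor which semantic class each instance belongs to, so two instances of the same class become two separate binary masks sharing one class label. A small sketch of the mapping the test helper builds:

    num_labels = 10
    labels_expanded = list(range(num_labels)) * 2  # two instances per semantic class
    instance_id_to_semantic_id = dict(enumerate(labels_expanded))
    print(instance_id_to_semantic_id[3], instance_id_to_semantic_id[13])  # 3 3 -> same class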
     def test_post_process_segmentation(self):
-        fature_extractor = self.feature_extraction_class()
+        fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
         outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
         segmentation = fature_extractor.post_process_segmentation(outputs)
@@ -340,7 +368,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         )

     def test_post_process_semantic_segmentation(self):
-        fature_extractor = self.feature_extraction_class()
+        fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
         outputs = self.feature_extract_tester.get_fake_maskformer_outputs()

         segmentation = fature_extractor.post_process_semantic_segmentation(outputs)
@@ -361,7 +389,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
         self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))

     def test_post_process_panoptic_segmentation(self):
-        fature_extractor = self.feature_extraction_class()
+        fature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
         outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
         segmentation = fature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
......
@@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

-    def test_with_annotations_and_loss(self):
+    def test_with_segmentation_maps_and_loss(self):
         model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
         feature_extractor = self.default_feature_extractor

         inputs = feature_extractor(
             [np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
-            annotations=[
-                {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
-                {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
-            ],
+            segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
             return_tensors="pt",
-        ).to(torch_device)
+        )
+
+        inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
+        inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
+        inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]

         with torch.no_grad():
             outputs = model(**inputs)
......