"docs/source/en/custom_models.md" did not exist on "32f5de10a01e2489cb0295d752f76ad81b20c5cb"
Unverified commit c4deb7b3, authored by Francesco Saverio Zuppichini, committed by GitHub

Feature Extractor accepts `segmentation_maps` (#15964)



* feature extractor accepts `segmentation_maps`

* resolved conversations

* added examples in test for ADE20K

* num_classes -> num_labels

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* resolving conversations

* resolving conversations

* removed ADE

* CI

* minor changes in conversion script

* reduce_labels in feature extractor

* minor changes

* correct preprocess for instance segmentation maps

* minor changes

* minor changes

* CI

* debugging

* better padding

* going to update labels inside the model

* going to update labels inside the model

* minor changes

* tests

* removed changes in feature_extractor_utils

* conversation

* conversation

* example in feature extractor

* more docstring in modeling

* test

* make style

* doc
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c2f8eaf6
@@ -169,12 +169,15 @@ class OriginalMaskFormerConfigToFeatureExtractorConverter:
def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
model = original_config.MODEL
model_input = original_config.INPUT
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])
return MaskFormerFeatureExtractor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
max_size=model_input.MAX_SIZE_TEST,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
ignore_index=dataset_catalog.ignore_label,
size_divisibility=32, # 32 is required by swin
)
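The hunk above wires `num_labels`, `ignore_index`, and `size_divisibility` into the converted feature extractor. The headline change of this commit is that the feature extractor can now take per-pixel `segmentation_maps` next to the images and turn them into the binary `mask_labels` / `class_labels` pairs the model trains on. A minimal sketch of that call, based on the tests further down in this diff (the random image and label values are placeholders):

```python
import numpy as np
from transformers import MaskFormerFeatureExtractor

# Placeholder inputs: one RGB image and a per-pixel semantic map with 10 classes.
image = np.random.rand(3, 512, 768).astype(np.float32)
segmentation_map = np.random.randint(0, 10, (512, 768)).astype(np.uint8)

feature_extractor = MaskFormerFeatureExtractor(num_labels=10, ignore_index=255, size_divisibility=32)
inputs = feature_extractor([image], segmentation_maps=[segmentation_map], return_tensors="pt")

print(inputs["pixel_values"].shape)    # (1, 3, H, W), H and W padded to a multiple of 32
print(inputs["mask_labels"][0].shape)  # (num_masks_in_image, H, W), one binary mask per class present
print(inputs["class_labels"][0])       # (num_masks_in_image,), the class id of each mask
```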
@@ -552,7 +555,7 @@ class OriginalMaskFormerCheckpointToOursConverter:
yield config, checkpoint
def test(original_model, our_model: MaskFormerForInstanceSegmentation):
def test(original_model, our_model: MaskFormerForInstanceSegmentation, feature_extractor: MaskFormerFeatureExtractor):
with torch.no_grad():
original_model = original_model.eval()
@@ -600,8 +603,6 @@ def test(original_model, our_model: MaskFormerForInstanceSegmentation):
our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)
feature_extractor = MaskFormerFeatureExtractor()
our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384))
assert torch.allclose(
@@ -707,7 +708,7 @@ if __name__ == "__main__":
mask_former_for_instance_segmentation
)
test(original_model, mask_former_for_instance_segmentation)
test(original_model, mask_former_for_instance_segmentation, feature_extractor)
model_name = get_name(checkpoint_file)
logger.info(f"🪄 Saving {model_name}")
......
@@ -269,7 +269,7 @@ class MaskFormerForInstanceSegmentationOutput(ModelOutput):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each
A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
@@ -424,7 +424,7 @@ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
# using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix
# using broadcasting to get a [num_queries, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
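For reference, `pair_wise_dice_loss` broadcasts every predicted mask against every ground-truth mask in one `einsum`, producing the full cost matrix the matcher needs. A toy run of the same arithmetic, with made-up sizes:

```python
import torch

# 3 predicted mask logits vs. 2 ground-truth binary masks, 16 pixels each (toy sizes).
inputs = torch.randn(3, 16)
labels = torch.randint(0, 2, (2, 16)).float()

probs = inputs.sigmoid()  # the function applies .sigmoid().flatten(1) first
numerator = 2 * torch.einsum("nc,mc->nm", probs, labels)
denominator = probs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
print(loss.shape)  # torch.Size([3, 2]): one dice cost per (prediction, label) pair
```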
@@ -918,7 +918,9 @@ class MaskFormerSwinBlock(nn.Module):
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C
shifted_windows = window_reverse(
attention_windows, self.window_size, height_pad, width_pad
) # B height' width' C
# reverse cyclic shift
if self.shift_size > 0:
@@ -1621,7 +1623,7 @@ class MaskFormerHungarianMatcher(nn.Module):
Params:
masks_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, num_classes` with the
A tensor of dim `batch_size, num_queries, num_labels` with the
classification logits.
class_queries_logits (`torch.Tensor`):
A tensor of dim `batch_size, num_queries, height, width` with the
@@ -1644,24 +1646,23 @@ class MaskFormerHungarianMatcher(nn.Module):
indices: List[Tuple[np.array]] = []
preds_masks = masks_queries_logits
preds_probs = class_queries_logits.softmax(dim=-1)
# downsample all masks in one go -> save memory
mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest")
preds_probs = class_queries_logits
# iterate through batch size
for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
# downsample the target mask, save memory
target_mask = nn.functional.interpolate(target_mask[:, None], size=pred_mask.shape[-2:], mode="nearest")
pred_probs = pred_probs.softmax(-1)
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be omitted.
cost_class = -pred_probs[:, labels]
# flatten spatial dimension "q h w -> q (h w)"
num_queries, height, width = pred_mask.shape
pred_mask_flat = pred_mask.view(num_queries, height * width) # [num_queries, H*W]
pred_mask_flat = pred_mask.flatten(1) # [num_queries, height*width]
# same for target_mask "c h w -> c (h w)"
num_channels, height, width = target_mask.shape
target_mask_flat = target_mask.view(num_channels, height * width) # [num_total_labels, H*W]
# compute the focal loss between each pair of masks -> shape [NUM_QUERIES, CLASSES]
target_mask_flat = target_mask[:, 0].flatten(1) # [num_total_labels, height*width]
# compute the focal loss between each pair of masks -> shape (num_queries, num_labels)
cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
# Compute the dice loss between each pair of masks -> shape [NUM_QUERIES, CLASSES]
# Compute the dice loss between each pair of masks -> shape (num_queries, num_labels)
cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
# final cost matrix
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
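Once the three pair-wise costs are summed into `cost_matrix`, the matcher solves the assignment per sample; given the `requires_backends(self, ["scipy"])` guard further down, this is SciPy's Hungarian solver. A sketch of that final step on a toy matrix (the weights here are illustrative, not the configured values):

```python
import torch
from scipy.optimize import linear_sum_assignment

num_queries, num_targets = 4, 2  # toy sizes
cost_class = torch.rand(num_queries, num_targets)
cost_mask = torch.rand(num_queries, num_targets)
cost_dice = torch.rand(num_queries, num_targets)

# same structure as the line above: a weighted sum of the three costs
cost_matrix = 20.0 * cost_mask + 1.0 * cost_class + 1.0 * cost_dice
pred_indices, target_indices = linear_sum_assignment(cost_matrix.numpy())
print(list(zip(pred_indices, target_indices)))  # each target mask matched to one query
```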
@@ -1691,7 +1692,7 @@ class MaskFormerHungarianMatcher(nn.Module):
class MaskFormerLoss(nn.Module):
def __init__(
self,
num_classes: int,
num_labels: int,
matcher: MaskFormerHungarianMatcher,
weight_dict: Dict[str, float],
eos_coef: float,
@@ -1702,7 +1703,7 @@ class MaskFormerLoss(nn.Module):
matched ground-truth / prediction (supervise class and mask)
Args:
num_classes (`int`):
num_labels (`int`):
The number of classes.
matcher (`MaskFormerHungarianMatcher`):
A torch module that computes the assignments between the predictions and labels.
@@ -1714,24 +1715,50 @@ class MaskFormerLoss(nn.Module):
super().__init__()
requires_backends(self, ["scipy"])
self.num_classes = num_classes
self.num_labels = num_labels
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
empty_weight = torch.ones(self.num_classes + 1)
empty_weight = torch.ones(self.num_labels + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
# get the maximum size in the batch
max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
batch_size = len(tensors)
# compute final size
batch_shape = [batch_size] + max_size
b, _, h, w = batch_shape
# get metadata
dtype = tensors[0].dtype
device = tensors[0].device
padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
# pad the tensors to the size of the biggest one
for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
return padded_tensors, padding_masks
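`_pad_images_to_max_in_batch` is new in this commit: since `mask_labels` is now a list of per-image tensors with different mask counts and spatial sizes, the loss first zero-pads them to the batch maximum and records where the padding is. A standalone copy of the helper to show the contract, with toy shapes:

```python
from typing import List, Tuple

import torch
from torch import Tensor

def pad_images_to_max_in_batch(tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
    # same logic as the method above, lifted out of the class for illustration
    max_size = [max(sizes) for sizes in zip(*[t.shape for t in tensors])]
    batch_shape = [len(tensors)] + max_size
    batch_size, _, height, width = batch_shape
    padded = torch.zeros(batch_shape, dtype=tensors[0].dtype, device=tensors[0].device)
    padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=tensors[0].device)
    for tensor, padded_tensor, padding_mask in zip(tensors, padded, padding_masks):
        padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
        padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
    return padded, padding_masks

masks_a = torch.ones(2, 3, 4)  # 2 masks of size 3x4
masks_b = torch.ones(1, 5, 2)  # 1 mask of size 5x2
padded, padding_masks = pad_images_to_max_in_batch([masks_a, masks_b])
print(padded.shape)         # torch.Size([2, 2, 5, 4])
print(padding_masks.shape)  # torch.Size([2, 5, 4]); True marks padded pixels
```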
def loss_labels(
self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array]
self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes`
class_labels (`Dict[str, Tensor]`):
A tensor of shape `batch_size, num_classes`
A tensor of shape `batch_size, num_queries, num_labels`
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
@@ -1744,21 +1771,21 @@ class MaskFormerLoss(nn.Module):
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices)
# shape = [BATCH, N_QUERIES]
# shape = (batch_size, num_queries)
target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
# shape = [BATCH, N_QUERIES]
# shape = (batch_size, num_queries)
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
(batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
# target_classes is a [BATCH, CLASSES, N_QUERIES], we need to permute pred_logits "b q c -> b c q"
pred_logits_permuted = pred_logits.permute(0, 2, 1)
loss_ce = criterion(pred_logits_permuted, target_classes)
# target_classes is a (batch_size, num_labels, num_queries), we need to permute pred_logits "b q c -> b c q"
pred_logits_transposed = pred_logits.transpose(1, 2)
loss_ce = criterion(pred_logits_transposed, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
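A side note on the `permute` → `transpose` change above: for a 3-D tensor the two are equivalent; the move is needed at all because `nn.CrossEntropyLoss` expects the class dimension second, i.e. `(batch_size, num_labels + 1, num_queries)` rather than `(batch_size, num_queries, num_labels + 1)`. A quick check with toy shapes:

```python
import torch
from torch import nn

batch_size, num_queries, num_labels = 2, 5, 3  # toy sizes
logits = torch.randn(batch_size, num_queries, num_labels + 1)  # +1 for the null class
targets = torch.randint(0, num_labels + 1, (batch_size, num_queries))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits.transpose(1, 2), targets)  # class axis must come second

# transpose(1, 2) and permute(0, 2, 1) agree on 3-D tensors
assert torch.equal(logits.transpose(1, 2), logits.permute(0, 2, 1))
```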
def loss_masks(
self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int
self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
) -> Dict[str, Tensor]:
"""Compute the losses related to the masks using focal and dice loss.
@@ -1766,7 +1793,7 @@ class MaskFormerLoss(nn.Module):
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
mask_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
List of mask labels of shape `(labels, height, width)`.
indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
num_masks (`int`):
@@ -1780,10 +1807,12 @@ class MaskFormerLoss(nn.Module):
"""
src_idx = self._get_predictions_permutation_indices(indices)
tgt_idx = self._get_targets_permutation_indices(indices)
pred_masks = masks_queries_logits # shape [BATCH, NUM_QUERIES, H, W]
pred_masks = pred_masks[src_idx] # shape [BATCH * NUM_QUERIES, H, W]
target_masks = mask_labels # shape [BATCH, NUM_QUERIES, H, W]
target_masks = target_masks[tgt_idx] # shape [BATCH * NUM_QUERIES, H, W]
# shape (batch_size * num_queries, height, width)
pred_masks = masks_queries_logits[src_idx]
# shape (batch_size, num_queries, height, width)
# pad all and stack the targets to the num_labels dimension
target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
target_masks = target_masks[tgt_idx]
# upsample predictions to the target size, we have to add one dim to use interpolate
pred_masks = nn.functional.interpolate(
pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
@@ -1791,7 +1820,6 @@ class MaskFormerLoss(nn.Module):
pred_masks = pred_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(pred_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
"loss_dice": dice_loss(pred_masks, target_masks, num_masks),
@@ -1810,19 +1838,13 @@ class MaskFormerLoss(nn.Module):
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def get_loss(self, loss, outputs, labels, indices, num_masks):
loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
if loss not in loss_map:
raise KeyError(f"{loss} not in loss_map")
return loss_map[loss](outputs, labels, indices, num_masks)
def forward(
self,
masks_queries_logits: torch.Tensor,
class_queries_logits: torch.Tensor,
mask_labels: torch.Tensor,
class_labels: torch.Tensor,
auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
mask_labels: List[Tensor],
class_labels: List[Tensor],
auxiliary_predictions: Optional[Dict[str, Tensor]] = None,
) -> Dict[str, Tensor]:
"""
This performs the loss computation.
@@ -1831,11 +1853,11 @@ class MaskFormerLoss(nn.Module):
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes`
A tensor of shape `batch_size, num_queries, num_labels`
mask_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_classes, height, width`
class_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_classes`
List of mask labels of shape `(labels, height, width)`.
class_labels (`List[torch.Tensor]`):
List of class labels of shape `(labels)`.
auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
if `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], then it contains the logits from the
inner layers of the DETR decoder.
@@ -1850,19 +1872,16 @@ class MaskFormerLoss(nn.Module):
for each auxiliary prediction.
"""
# Retrieve the matching between the outputs of the last layer and the labels
# retrieve the matching between the outputs of the last layer and the labels
indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
# Compute the average number of target masks across all nodes, for normalization purposes
num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device)
# Compute all the requested losses
# compute the average number of target masks for normalization purposes
num_masks: Number = self.get_num_masks(class_labels, device=class_labels[0].device)
# get all the losses
losses: Dict[str, Tensor] = {
**self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
**self.loss_labels(class_queries_logits, class_labels, indices),
}
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
# in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if auxiliary_predictions is not None:
for idx, aux_outputs in enumerate(auxiliary_predictions):
masks_queries_logits = aux_outputs["masks_queries_logits"]
@@ -1874,8 +1893,10 @@ class MaskFormerLoss(nn.Module):
return losses
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
# Compute the average number of target masks across all nodes, for normalization purposes
num_masks = class_labels.shape[0]
"""
Computes the average number of target masks across the batch, for normalization purposes.
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
return num_masks_pt
@@ -2380,11 +2401,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
loss_dict: Dict[str, Tensor] = self.criterion(
masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
)
# weight each loss by `self.weight_dict[<LOSS_NAME>]`
weighted_loss_dict: Dict[str, Tensor] = {
k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict
}
return weighted_loss_dict
# weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
for key, weight in self.weight_dict.items():
for loss_key, loss in loss_dict.items():
if key in loss_key:
loss *= weight
return loss_dict
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
return sum(loss_dict.values())
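The rewritten weighting above scales every entry of `loss_dict` whose key contains a `weight_dict` key as a substring, which is what lets a single `loss_mask` weight also cover the per-layer auxiliary keys (which carry a layer-index suffix, e.g. `loss_mask_0`); `get_loss` then just sums the dict. Because `loss *= weight` is an in-place tensor op, the dict values themselves are updated. A toy illustration (key names mirror that pattern, values are made up):

```python
import torch

weight_dict = {"loss_mask": 20.0, "loss_dice": 1.0}
loss_dict = {
    "loss_mask": torch.tensor(1.0),
    "loss_dice": torch.tensor(1.0),
    "loss_mask_0": torch.tensor(1.0),  # auxiliary loss from an intermediate decoder layer
}

for key, weight in weight_dict.items():
    for loss_key, loss in loss_dict.items():
        if key in loss_key:
            loss *= weight  # in-place on the tensor, so loss_dict sees the update

print({k: v.item() for k, v in loss_dict.items()})
# {'loss_mask': 20.0, 'loss_dice': 1.0, 'loss_mask_0': 20.0}
```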
@@ -2425,8 +2448,8 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[Tensor] = None,
class_labels: Optional[Tensor] = None,
mask_labels: Optional[List[Tensor]] = None,
class_labels: Optional[List[Tensor]] = None,
pixel_mask: Optional[Tensor] = None,
output_auxiliary_logits: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -2434,10 +2457,11 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
return_dict: Optional[bool] = None,
) -> MaskFormerForInstanceSegmentationOutput:
r"""
mask_labels (`torch.FloatTensor`, *optional*):
The target mask of shape `(num_classes, height, width)`.
class_labels (`torch.LongTensor`, *optional*):
The target labels of shape `(num_classes)`.
mask_labels (`List[torch.Tensor]`, *optional*):
List of mask labels of shape `(num_labels, height, width)` to be fed to a model.
class_labels (`List[torch.LongTensor]`, *optional*):
List of target class labels of shape `(num_labels,)` to be fed to a model. They identify the
labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
Returns:
......
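Putting the new pieces together, a training step now looks like: preprocess images and segmentation maps, pass the resulting lists straight to the model, and read the loss off the output. A minimal sketch mirroring the integration test at the bottom of this diff (the checkpoint is one of the released MaskFormer checkpoints; image sizes and label values are placeholders):

```python
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")

images = [np.zeros((3, 512, 512), dtype=np.float32)]
segmentation_maps = [np.random.randint(0, 10, (512, 512)).astype(np.uint8)]

inputs = feature_extractor(images, segmentation_maps=segmentation_maps, return_tensors="pt")
outputs = model(
    pixel_values=inputs["pixel_values"],
    mask_labels=inputs["mask_labels"],    # List[Tensor], one (num_labels, H, W) tensor per image
    class_labels=inputs["class_labels"],  # List[Tensor], one (num_labels,) tensor per image
)
outputs.loss.backward()  # the criterion above produced a weighted, summed loss
```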
@@ -49,6 +49,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
num_labels=10,
reduce_labels=True,
ignore_index=255,
):
self.parent = parent
self.batch_size = batch_size
@@ -68,6 +71,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
self.num_classes = 2
self.height = 3
self.width = 4
self.num_labels = num_labels
self.reduce_labels = reduce_labels
self.ignore_index = ignore_index
def prepare_feat_extract_dict(self):
return {
@@ -78,6 +84,9 @@ class MaskFormerFeatureExtractionTester(unittest.TestCase):
"image_mean": self.image_mean,
"image_std": self.image_std,
"size_divisibility": self.size_divisibility,
"num_labels": self.num_labels,
"reduce_labels": self.reduce_labels,
"ignore_index": self.ignore_index,
}
def get_expected_values(self, image_inputs, batched=False):
@@ -140,6 +149,8 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "max_size"))
self.assertTrue(hasattr(feature_extractor, "ignore_index"))
self.assertTrue(hasattr(feature_extractor, "num_labels"))
def test_batch_feature(self):
pass
@@ -245,7 +256,9 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize feature_extractors
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
feature_extractor_2 = self.feature_extraction_class(
do_resize=False, do_normalize=False, num_labels=self.feature_extract_tester.num_classes
)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
@@ -262,28 +275,41 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
)
def comm_get_feature_extractor_inputs(self, with_annotations=False):
def comm_get_feature_extractor_inputs(
self, with_segmentation_maps=False, is_instance_map=False, segmentation_type="np"
):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# prepare image and target
num_classes = 8
batch_size = self.feature_extract_tester.batch_size
num_labels = self.feature_extract_tester.num_labels
annotations = None
if with_annotations:
annotations = [
{
"masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
"labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
instance_id_to_semantic_id = None
if with_segmentation_maps:
high = num_labels
if is_instance_map:
high *= 2
labels_expanded = list(range(num_labels)) * 2
instance_id_to_semantic_id = {
instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)
}
for _ in range(batch_size)
]
annotations = [np.random.randint(0, high, (384, 384)).astype(np.uint8) for _ in range(batch_size)]
if segmentation_type == "pil":
annotations = [Image.fromarray(annotation) for annotation in annotations]
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
inputs = feature_extractor(
image_inputs,
annotations,
return_tensors="pt",
instance_id_to_semantic_id=instance_id_to_semantic_id,
pad_and_return_pixel_mask=True,
)
return inputs
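The `instance_id_to_semantic_id` argument introduced above tells the feature extractor how to map the ids found in an instance segmentation map (one id per object) back to semantic class ids. The test builds it so that two objects share each class; a standalone look at that mapping:

```python
# Same construction as in the test above, with num_labels = 3 for readability.
num_labels = 3
labels_expanded = list(range(num_labels)) * 2  # [0, 1, 2, 0, 1, 2]
# pixel value i in the instance map belongs to semantic class labels_expanded[i]
instance_id_to_semantic_id = {instance_id: label_id for instance_id, label_id in enumerate(labels_expanded)}
print(instance_id_to_semantic_id)  # {0: 0, 1: 1, 2: 2, 3: 0, 4: 1, 5: 2}
```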
def test_init_without_params(self):
pass
def test_with_size_divisibility(self):
size_divisibilities = [8, 16, 32]
weird_input_sizes = [(407, 802), (582, 1094)]
@@ -297,27 +323,29 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)
def test_call_with_numpy_annotations(self):
num_classes = 8
batch_size = self.feature_extract_tester.batch_size
inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
# check the batch_size
for el in inputs.values():
self.assertEqual(el.shape[0], batch_size)
def test_call_with_segmentation_maps(self):
def common(is_instance_map=False, segmentation_type=None):
inputs = self.comm_get_feature_extractor_inputs(
with_segmentation_maps=True, is_instance_map=is_instance_map, segmentation_type=segmentation_type
)
pixel_values = inputs["pixel_values"]
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
pixel_values = inputs["pixel_values"]
# check the batch_size
for mask_label, class_label in zip(mask_labels, class_labels):
self.assertEqual(mask_label.shape[0], class_label.shape[0])
# this ensures padding has happened
self.assertEqual(mask_label.shape[1:], pixel_values.shape[2:])
self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
self.assertEqual(mask_labels.shape[1], num_classes)
common()
common(is_instance_map=True)
common(is_instance_map=False, segmentation_type="pil")
common(is_instance_map=True, segmentation_type="pil")
def test_post_process_segmentation(self):
feature_extractor = self.feature_extraction_class()
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_segmentation(outputs)
@@ -340,7 +368,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
)
def test_post_process_semantic_segmentation(self):
feature_extractor = self.feature_extraction_class()
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_semantic_segmentation(outputs)
@@ -361,7 +389,7 @@ class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
self.assertEqual(segmentation.shape, (self.feature_extract_tester.batch_size, *target_size))
def test_post_process_panoptic_segmentation(self):
feature_extractor = self.feature_extraction_class()
feature_extractor = self.feature_extraction_class(num_labels=self.feature_extract_tester.num_classes)
outputs = self.feature_extract_tester.get_fake_maskformer_outputs()
segmentation = feature_extractor.post_process_panoptic_segmentation(outputs, object_mask_threshold=0)
......
@@ -397,18 +397,19 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_annotations_and_loss(self):
def test_with_segmentation_maps_and_loss(self):
model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
feature_extractor = self.default_feature_extractor
inputs = feature_extractor(
[np.zeros((3, 800, 1333)), np.zeros((3, 800, 1333))],
annotations=[
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
],
segmentation_maps=[np.zeros((384, 384)).astype(np.float32), np.zeros((384, 384)).astype(np.float32)],
return_tensors="pt",
).to(torch_device)
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch_device)
inputs["mask_labels"] = [el.to(torch_device) for el in inputs["mask_labels"]]
inputs["class_labels"] = [el.to(torch_device) for el in inputs["class_labels"]]
with torch.no_grad():
outputs = model(**inputs)
......