"examples/vscode:/vscode.git/clone" did not exist on "8aa67fc192c1485f499e4c2dcb22bf8ad245160b"
Unverified commit ce2d4bc6, authored by amyeroberts, committed by GitHub

MaskFormer, Mask2Former - reduce memory load (#25741)

Allocate result array ahead of time
parent 0daeeb40
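The pattern in all three hunks below is the same: a broadcast multiply followed by a sum over the channel axis is replaced with a per-channel accumulation into a preallocated result, so the large intermediate tensor is never materialized. A minimal sketch of the equivalence (shapes, names, and tolerances here are illustrative assumptions, not values from the model config):

import torch

# Illustrative shapes (assumed for this sketch, not taken from the commit)
batch_size, num_queries, num_channels, height, width = 1, 50, 64, 64, 64
mask_embeddings = torch.randn(batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

# Reference result: the einsum named in the new code's comment
reference = torch.einsum("bqc,bchw->bqhw", mask_embeddings, pixel_embeddings)

# Old approach: broadcasting materializes a (b, q, c, h, w) intermediate
old = (mask_embeddings.unsqueeze(-1).unsqueeze(-1) * pixel_embeddings.unsqueeze(1)).sum(2)

# New approach: preallocate the (b, q, h, w) result and accumulate one
# channel at a time; the largest temporary is a single (b, q, h, w) slice
new = torch.zeros((batch_size, num_queries, height, width))
for c in range(num_channels):
    new += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]

# Loose tolerance: float32 summation order differs between the variants
assert torch.allclose(reference, old, atol=1e-4)
assert torch.allclose(reference, new, atol=1e-4)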
@@ -2011,13 +2011,12 @@ class Mask2FormerMaskPredictor(nn.Module):
     def forward(self, outputs: torch.Tensor, pixel_embeddings: torch.Tensor, attention_mask_target_size: int = None):
         mask_embeddings = self.mask_embedder(outputs.transpose(0, 1))
 
-        # Sum up over the channels
-        # (batch_size, num_queries, num_channels, 1, 1)
-        mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-        # (batch_size, 1, num_channels, height, width)
-        pixel_embeddings = pixel_embeddings.unsqueeze(1)
-        # (batch_size, num_queries, height, width)
-        outputs_mask = (mask_embeddings * pixel_embeddings).sum(2)
+        # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+        batch_size, num_queries, num_channels = mask_embeddings.shape
+        _, _, height, width = pixel_embeddings.shape
+        outputs_mask = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+        for c in range(num_channels):
+            outputs_mask += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
 
         attention_mask = nn.functional.interpolate(
             outputs_mask, size=attention_mask_target_size, mode="bilinear", align_corners=False
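For a sense of scale: with illustrative Mask2Former-like sizes (batch 1, 100 queries, 256 channels, 96x96 masks; assumptions, not config values), the broadcast intermediate holds 1 * 100 * 256 * 96 * 96, about 2.4e8 float32 values or roughly 900 MiB, while the preallocated (batch_size, num_queries, height, width) result holds about 9.2e5 values, under 4 MiB. The loop never holds more than one such result-sized temporary at a time.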
@@ -1789,13 +1789,15 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             class_queries_logits = classes[-1]
             # get the masks
             mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
-            # sum up over the channels for each embedding
-            # (num_embeddings, batch_size, num_queries, num_channels, 1, 1)
-            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-            # (1, batch_size, 1, num_channels, height, width)
-            pixel_embeddings = pixel_embeddings.unsqueeze(0).unsqueeze(2)
-            # (num_embeddings, batch_size, num_queries, height, width)
-            binaries_masks = (mask_embeddings * pixel_embeddings).sum(dim=3)
+
+            # Equivalent to einsum('lbqc, bchw -> lbqhw') but jit friendly
+            num_embeddings, batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            binaries_masks = torch.zeros(
+                (num_embeddings, batch_size, num_queries, height, width), device=mask_embeddings.device
+            )
+            for c in range(num_channels):
+                binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]
+
             masks_queries_logits = binaries_masks[-1]
             # go til [:-1] because the last one is always used
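The stacked case adds a leading axis over decoder layers (the 'l' in the einsum string), which is why pixel_embeddings is indexed as [None, :, None, c]. A small check, with assumed toy shapes, that the loop reproduces einsum('lbqc, bchw -> lbqhw'):

import torch

# Toy shapes, assumed for this sketch
num_embeddings, batch_size, num_queries, num_channels, height, width = 4, 2, 8, 16, 12, 12
mask_embeddings = torch.randn(num_embeddings, batch_size, num_queries, num_channels)
pixel_embeddings = torch.randn(batch_size, num_channels, height, width)

reference = torch.einsum("lbqc,bchw->lbqhw", mask_embeddings, pixel_embeddings)

binaries_masks = torch.zeros((num_embeddings, batch_size, num_queries, height, width))
for c in range(num_channels):
    # (l, b, q, 1, 1) broadcast against (1, b, 1, h, w) for channel c
    binaries_masks += mask_embeddings[..., c][..., None, None] * pixel_embeddings[None, :, None, c]

assert torch.allclose(reference, binaries_masks, atol=1e-4)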
@@ -1811,12 +1813,13 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
             # get the masks
             mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
             # sum up over the channels
-            # (batch_size, num_queries, num_channels, 1, 1)
-            mask_embeddings = mask_embeddings.unsqueeze(-1).unsqueeze(-1)
-            # (batch_size, 1, num_channels, height, width)
-            pixel_embeddings = pixel_embeddings.unsqueeze(1)
-            # (batch_size, num_queries, height, width)
-            masks_queries_logits = (mask_embeddings * pixel_embeddings).sum(dim=2)
+
+            # Equivalent to einsum('bqc, bchw -> bqhw') but jit friendly
+            batch_size, num_queries, num_channels = mask_embeddings.shape
+            _, _, height, width = pixel_embeddings.shape
+            masks_queries_logits = torch.zeros((batch_size, num_queries, height, width), device=mask_embeddings.device)
+            for c in range(num_channels):
+                masks_queries_logits += mask_embeddings[..., c][..., None, None] * pixel_embeddings[:, None, c]
+
         return class_queries_logits, masks_queries_logits, auxiliary_logits
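The saving is easy to confirm empirically. A hedged sketch comparing peak GPU memory for the two variants (shapes are assumptions, not config values; shrink them if memory is tight):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    b, q, c, h, w = 1, 100, 256, 96, 96  # assumed sizes
    mask_embeddings = torch.randn(b, q, c, device=device)
    pixel_embeddings = torch.randn(b, c, h, w, device=device)

    # Broadcast variant: peak includes the (b, q, c, h, w) intermediate
    torch.cuda.reset_peak_memory_stats()
    out = (mask_embeddings[..., None, None] * pixel_embeddings[:, None]).sum(2)
    print("broadcast peak:", torch.cuda.max_memory_allocated() // 2**20, "MiB")

    del out
    torch.cuda.empty_cache()

    # Loop variant: peak stays near the preallocated (b, q, h, w) result
    torch.cuda.reset_peak_memory_stats()
    out = torch.zeros((b, q, h, w), device=device)
    for ch in range(c):
        out += mask_embeddings[..., ch][..., None, None] * pixel_embeddings[:, None, ch]
    print("loop peak:", torch.cuda.max_memory_allocated() // 2**20, "MiB")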